DHEIVER commited on
Commit
5afbda2
·
verified ·
1 Parent(s): 159d853

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -1,19 +1,15 @@
1
  import gradio as gr
2
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
3
- import requests
4
- from PIL import Image
5
- import torch, os, re, json
6
- import spaces
7
 
 
8
  torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/74801584018932.png', 'chart_example_1.png')
9
  torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_2.png')
10
 
11
-
12
-
13
  model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma")
14
  processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
15
 
16
-
17
  @spaces.GPU
18
  def predict(image, input_text):
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -26,20 +22,25 @@ def predict(image, input_text):
26
 
27
  prompt_length = inputs['input_ids'].shape[1]
28
 
29
- # Generate
30
  generate_ids = model.generate(**inputs, max_new_tokens=512)
31
  output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
32
 
33
  return output_text
34
 
35
-
36
- image = gr.components.Image(type="pil", label="Chart Image")
37
- input_prompt = gr.components.Textbox(label="Input Prompt")
38
- model_output = gr.components.Textbox(label="Model Output")
39
- examples = [["chart_example_1.png", "Describe the trend of the mortality rates for children before age 5"],
40
- ["chart_example_2.png", "What is the share of respondants who prefer Facebook Messenger in the 30-59 age group?"]]
 
 
 
 
 
41
 
42
- title = "Interactive Gradio Demo for ChartGemma model"
43
  interface = gr.Interface(fn=predict,
44
  inputs=[image, input_prompt],
45
  outputs=model_output,
@@ -47,4 +48,4 @@ interface = gr.Interface(fn=predict,
47
  title=title,
48
  theme='gradio/soft')
49
 
50
- interface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
3
+ import torch
 
 
 
4
 
5
+ # Baixando imagens de exemplo
6
  torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/74801584018932.png', 'chart_example_1.png')
7
  torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_2.png')
8
 
9
+ # Carregando modelo e processador
 
10
  model = PaliGemmaForConditionalGeneration.from_pretrained("ahmed-masry/chartgemma")
11
  processor = AutoProcessor.from_pretrained("ahmed-masry/chartgemma")
12
 
 
13
  @spaces.GPU
14
  def predict(image, input_text):
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
22
 
23
  prompt_length = inputs['input_ids'].shape[1]
24
 
25
+ # Geração
26
  generate_ids = model.generate(**inputs, max_new_tokens=512)
27
  output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
28
 
29
  return output_text
30
 
31
+ # Definindo os componentes da interface
32
+ image = gr.components.Image(type="pil", label="Imagem do Gráfico")
33
+ input_prompt = gr.components.Textbox(label="Texto de Entrada")
34
+ model_output = gr.components.Textbox(label="Saída do Modelo")
35
+
36
+ # Exemplos
37
+ examples = [["chart_example_1.png", "Descreva a tendência das taxas de mortalidade para crianças menores de 5 anos"],
38
+ ["chart_example_2.png", "Qual é a proporção de respondentes que preferem o Facebook Messenger no grupo etário de 30 a 59 anos?"]]
39
+
40
+ # Título da interface
41
+ title = "Demo Interativa do Modelo ChartGemma"
42
 
43
+ # Criando e lançando a interface
44
  interface = gr.Interface(fn=predict,
45
  inputs=[image, input_prompt],
46
  outputs=model_output,
 
48
  title=title,
49
  theme='gradio/soft')
50
 
51
+ interface.launch()