prithivMLmods committed
Commit 239e8eb · verified · 1 Parent(s): b625ee3

Update app.py

Files changed (1)
  1. app.py +35 -38
app.py CHANGED
@@ -28,40 +28,6 @@ aya_model = AutoModelForImageTextToText.from_pretrained(
     AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
 )
 
-def aya_vision_chat(image, text_prompt):
-    # If image is provided as a URL, load it via requests.
-    if isinstance(image, str):
-        response = requests.get(image)
-        image = Image.open(BytesIO(response.content))
-
-    messages = [{
-        "role": "user",
-        "content": [
-            {"type": "image", "image": image},
-            {"type": "text", "text": text_prompt},
-        ],
-    }]
-
-    inputs = aya_processor.apply_chat_template(
-        messages,
-        padding=True,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt"
-    ).to(aya_model.device)
-
-    gen_tokens = aya_model.generate(
-        **inputs, max_new_tokens=300, do_sample=True, temperature=0.3
-    )
-
-    # Decode only the newly generated tokens.
-    response_text = aya_processor.tokenizer.decode(
-        gen_tokens[0][inputs.input_ids.shape[1]:],
-        skip_special_tokens=True
-    )
-    return response_text
-
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"].strip()
@@ -77,9 +43,40 @@ def model_inference(input_dict, history):
         # For simplicity, use the first provided image.
         image = load_image(files[0])
         yield "Processing with Aya-Vision ███████▒▒▒ 69%"
-        response_text = aya_vision_chat(image, text_prompt)
-        yield response_text
+        messages = [{
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": text_prompt},
+            ],
+        }]
+        inputs = aya_processor.apply_chat_template(
+            messages,
+            padding=True,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(aya_model.device)
+        # Set up a streamer for Aya-Vision output
+        streamer = TextIteratorStreamer(aya_processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = dict(
+            inputs,
+            streamer=streamer,
+            max_new_tokens=300,
+            do_sample=True,
+            temperature=0.3
+        )
+        thread = Thread(target=aya_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
         return
+
     # Load images if provided.
     if len(files) > 1:
         images = [load_image(image) for image in files]
@@ -146,9 +143,9 @@ examples = [
 
 demo = gr.ChatInterface(
     fn=model_inference,
-    description="# **Multimodal OCR** `@aya-vision 'prompt..'`",
+    description="# Multimodal OCR `@aya-vision 'prompt..'`",
     examples=examples,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="By default, it runs Qwen2VL. Tag @aya-vision for Aya Vision 8B"),
     stop_btn="Stop Generation",
     multimodal=True,
     cache_examples=False,
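
For context, this commit replaces the blocking aya_vision_chat helper with inline token streaming driven by transformers' TextIteratorStreamer, running generate() on a background Thread and yielding a growing buffer back to the Gradio chat handler. Below is a minimal, self-contained sketch of that pattern under stated assumptions: the "gpt2" model id, the prompt text, and the print-based consumer are illustrative placeholders only and are not part of this Space.

# Minimal sketch of the TextIteratorStreamer + Thread streaming pattern.
# Assumptions: "gpt2" as a stand-in model; a real Gradio handler would `yield buffer`
# back to gr.ChatInterface instead of printing.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder text model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Streaming lets the UI show partial output while generation runs:", return_tensors="pt")

# skip_prompt drops the prompt tokens from the stream; skip_special_tokens cleans the text.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs in a worker thread while the
# main thread consumes decoded chunks from the streamer as they become available.
thread = Thread(
    target=model.generate,
    kwargs=dict(inputs, streamer=streamer, max_new_tokens=40, do_sample=True, temperature=0.3),
)
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(buffer, flush=True)  # in the Space, this line is `yield buffer`
thread.join()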