prateekbh commited on
Commit
7d58261
1 Parent(s): 14901fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import gradio as gr
2
  import spaces
3
  import argparse
 
 
 
4
 
5
 
6
  parser = argparse.ArgumentParser()
@@ -23,6 +26,72 @@ div#col-container {
23
  }
24
  """
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  with gr.Blocks(css=css) as demo:
27
  with gr.Column(elem_id="col-container"):
28
  gr.HTML(title)
 
1
  import gradio as gr
2
  import spaces
3
  import argparse
4
+ import torch
5
+ from transformers import AutoModel, AutoProcessor
6
+ from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
7
 
8
 
9
  parser = argparse.ArgumentParser()
 
26
  }
27
  """
28
 
29
# UForm-Gen2 (Qwen-500m) vision-language checkpoint used by the demo.
# NOTE(review): `device` is not defined in this hunk — presumably assigned
# earlier in app.py (e.g. "cuda" / "cpu"); confirm against the full file.
_MODEL_ID = "unum-cloud/uform-gen2-qwen-500m"
model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(_MODEL_ID, trust_remote_code=True)
31
+
32
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the most recent token is an end-of-turn id.

    Used with ``StoppingCriteriaList`` during ``model.generate`` streaming.
    """

    # 151645 — presumably Qwen's "<|im_end|>" end-of-turn token id for this
    # chat template; confirm against the tokenizer config.
    _STOP_IDS = (151645,)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        latest = input_ids[0][-1]
        return any(latest == stop_id for stop_id in self._STOP_IDS)
39
+
40
@torch.no_grad()
def response(message, history, image):
    """Stream a chat reply about *image* as a Gradio generator handler.

    Args:
        message: The user's new text message.
        history: List of ``[user_msg, assistant_msg]`` pairs from the Chatbot.
        image: Image object accepted by ``processor.feature_extractor``
            (presumably a PIL image — confirm against the Gradio input).

    Yields:
        ``(updated_history, gr.Button(visible=False),
        gr.Button(visible=True, interactive=True))`` after each streamed token.
    """
    # BUGFIX: `Thread` is used below but the file's import hunk (lines 1-9 of
    # the new app.py) never imports it, which would raise NameError at the
    # first call. Import locally; harmless if also imported at module level.
    from threading import Thread

    stop = StopOnTokens()

    messages = [{"role": "system", "content": "You are a helpful assistant."}]

    # Replay prior turns so the model sees the full conversation.
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # First user turn: prepend the <image> placeholder expected by the
    # UForm-Gen2 chat template.
    if len(messages) == 1:
        message = f" <image>{message}"

    messages.append({"role": "user", "content": message})

    model_inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )

    # Add a leading batch dimension for the vision tower.
    image = (
        processor.feature_extractor(image)
        .unsqueeze(0)
    )

    # The single <image> placeholder expands into `num_image_latents` image
    # tokens, hence the "- 1" when sizing the attention mask.
    attention_mask = torch.ones(
        1, model_inputs.shape[1] + processor.num_image_latents - 1
    )

    model_inputs = {
        "input_ids": model_inputs,
        "images": image,
        "attention_mask": attention_mask
    }

    # NOTE(review): `device` is not defined in this hunk — presumably set
    # earlier in app.py; verify.
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    # Run generation on a worker thread so this generator can stream tokens
    # to the UI as soon as they are produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        history[-1][1] = partial_response
        yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)
94
+
95
  with gr.Blocks(css=css) as demo:
96
  with gr.Column(elem_id="col-container"):
97
  gr.HTML(title)