Marlon Wiprud committed
Commit 68677a4 · 1 Parent(s): 104960d

feat: setup handler;

Files changed (2)
  1. handler.py +95 -0
  2. requirements.txt +11 -0
handler.py ADDED
@@ -0,0 +1,95 @@
+ from typing import Dict, List, Any
+ from transformers import pipeline
+ from PIL import Image
+ import requests
+ from transformers import AutoModelForCausalLM, LlamaTokenizer
+ import torch
+
+ # from accelerate import (
+ #     init_empty_weights,
+ #     infer_auto_device_map,
+ #     load_checkpoint_and_dispatch,
+ # )
+ import os
+
+ import logging
+
+ # from transformers import logging as hf_logging
+ # hf_logging.set_verbosity_debug()
+
+ logging.basicConfig(level=logging.INFO)
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")  # CogVLM uses the Vicuna-7B tokenizer
+
+         self.model = (
+             AutoModelForCausalLM.from_pretrained(
+                 "THUDM/cogvlm-grounding-generalist-hf",
+                 torch_dtype=torch.bfloat16,
+                 low_cpu_mem_usage=True,
+                 trust_remote_code=True,
+             )
+             .to("cuda")
+             .eval()
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> str:
+         """
+         data args:
+             inputs (:obj:`str`): the text query
+             img_uri (:obj:`str`): URI of the image to query against
+         Return:
+             A :obj:`str`: the decoded generation; will be serialized and returned
+         """
+
+         query = data["inputs"]
+         img_uri = data["img_uri"]
+
+         image = Image.open(  # fetch the image and normalize it to RGB
+             requests.get(
+                 img_uri,
+                 stream=True,
+             ).raw
+         ).convert("RGB")
+
+         inputs = self.model.build_conversation_input_ids(
+             self.tokenizer, query=query, images=[image]
+         )
+         inputs = {
+             "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
+             "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
+             "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
+             "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)]],
+         }
+         gen_kwargs = {"max_length": 2048, "do_sample": False}
+
+         with torch.no_grad():
+             outputs = self.model.generate(**inputs, **gen_kwargs)
+             outputs = outputs[:, inputs["input_ids"].shape[1] :]  # drop the prompt tokens
+             result = self.tokenizer.decode(outputs[0])
+             return result
+
+
+ # query = "How many houses are there in this cartoon?"
+ # image = Image.open(
+ #     requests.get(
+ #         "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true", stream=True
+ #     ).raw
+ # ).convert("RGB")
+ # inputs = model.build_conversation_input_ids(
+ #     tokenizer, query=query, history=[], images=[image], template_version="vqa"
+ # )  # vqa mode
+ # inputs = {
+ #     "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
+ #     "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
+ #     "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
+ #     "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)]],
+ # }
+ # gen_kwargs = {"max_length": 2048, "do_sample": False}
+
+ # with torch.no_grad():
+ #     outputs = model.generate(**inputs, **gen_kwargs)
+ #     outputs = outputs[:, inputs["input_ids"].shape[1] :]
+ #     print(tokenizer.decode(outputs[0]))
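Note: a quick local smoke test of the handler above could look like the following. This is a minimal sketch, not part of the commit: it assumes a CUDA machine with the requirements installed, and reuses the example image URL from the commented-out block in handler.py.

from handler import EndpointHandler

# Instantiating the handler downloads the tokenizer and model weights.
handler = EndpointHandler()

# The handler reads the text query from "inputs" and the image URL from "img_uri".
payload = {
    "inputs": "How many houses are there in this cartoon?",
    "img_uri": "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true",
}
print(handler(payload))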
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ einops
+ Pillow==10.1.0
+ # torch==2.1.0
+ torch==1.13.1
+ # transformers==4.35.0
+ accelerate==0.24.1
+ sentencepiece==0.1.99
+ einops==0.7.0
+ # xformers==0.0.22.post7
+ xformers
+ triton==2.1.0
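Note: once deployed as a Hugging Face Inference Endpoint, the handler can be exercised over HTTP. A hedged sketch follows; the endpoint URL and the HF_TOKEN environment variable are placeholders for your own deployment, not values from this commit.

import os
import requests

# Placeholder: substitute the URL of your own Inference Endpoint deployment.
ENDPOINT_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

# The JSON body is passed to EndpointHandler.__call__ as `data`.
resp = requests.post(
    ENDPOINT_URL,
    headers=headers,
    json={
        "inputs": "How many houses are there in this cartoon?",
        "img_uri": "https://github.com/THUDM/CogVLM/blob/main/examples/3.jpg?raw=true",
    },
)
print(resp.json())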