Hjgugugjhuhjggg committed on
Commit
edb59e5
·
verified ·
1 Parent(s): 8ae0197

Create app.py

Files changed (1)
  1. app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
import os
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
from transformers import (
    AutoConfig,
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteriaList
)
import uvicorn
import asyncio
from io import BytesIO

app = FastAPI()

class GenerateRequest(BaseModel):
    model_name: str
    input_text: str = ""
    task_type: str
    temperature: float = 1.0
    max_new_tokens: int = 200
    stream: bool = True
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = True
    chunk_delay: float = 0.0
    stop_sequences: list[str] = []

    @field_validator("model_name")
    def model_name_cannot_be_empty(cls, v):
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v

class LocalModelLoader:
    def __init__(self):
        pass

    async def load_model_and_tokenizer(self, model_name):
        try:
            config = AutoConfig.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
            model = AutoModelForCausalLM.from_pretrained(model_name, config=config)

            # Ensure a pad token is defined, since some causal LMs ship without one.
            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id

            return model, tokenizer
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

model_loader = LocalModelLoader()

@app.post("/generate")
async def generate(request: GenerateRequest):
    try:
        model_name = request.model_name
        input_text = request.input_text
        task_type = request.task_type
        temperature = request.temperature
        max_new_tokens = request.max_new_tokens
        stream = request.stream
        top_p = request.top_p
        top_k = request.top_k
        repetition_penalty = request.repetition_penalty
        num_return_sequences = request.num_return_sequences
        do_sample = request.do_sample
        chunk_delay = request.chunk_delay
        stop_sequences = request.stop_sequences

        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        generation_config = GenerationConfig(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
        )

        return StreamingResponse(
            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
            media_type="text/plain"
        )

    except HTTPException:
        # Preserve the status and detail raised while loading the model.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    input_length = encoded_input["input_ids"].shape[1]
    remaining_tokens = max_length - input_length

    if remaining_tokens <= 0:
        # The prompt already fills the context window; nothing left to generate.
        yield ""
        return

    generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)

    def stop_criteria(input_ids, scores, **kwargs):
        # Stop once the most recently generated token decodes to one of the
        # requested stop sequences (only single-token sequences are caught here).
        decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
        return decoded_output in stop_sequences

    stopping_criteria = StoppingCriteriaList([stop_criteria])

    # Generation itself runs synchronously; the tokens are streamed out afterwards.
    outputs = model.generate(
        **encoded_input,
        do_sample=generation_config.do_sample,
        max_new_tokens=generation_config.max_new_tokens,
        temperature=generation_config.temperature,
        top_p=generation_config.top_p,
        top_k=generation_config.top_k,
        repetition_penalty=generation_config.repetition_penalty,
        num_return_sequences=generation_config.num_return_sequences,
        stopping_criteria=stopping_criteria,
        output_scores=True,
        return_dict_in_generate=True
    )

    output_text = ""
    for output in outputs.sequences:
        # Skip the prompt tokens so only newly generated text is streamed back.
        for token_id in output[input_length:]:
            token = tokenizer.decode(token_id, skip_special_tokens=True)
            output_text += token
            yield token
            await asyncio.sleep(chunk_delay)
            if stop_sequences and any(stop in output_text for stop in stop_sequences):
                return

@app.post("/generate-image")
async def generate_image(request: GenerateRequest):
    try:
        validated_body = request
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # NOTE: transformers' pipeline() does not ship a "text-to-image" task; this
        # assumes a compatible pipeline (e.g. a diffusers text-to-image pipeline) is
        # available and returns a PIL-style image exposing .save().
        image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
        image = image_generator(validated_body.input_text)[0]

        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)

        return StreamingResponse(img_byte_arr, media_type="image/png")

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/generate-text-to-speech")
async def generate_text_to_speech(request: GenerateRequest):
    try:
        validated_body = request
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # NOTE: the transformers "text-to-speech" pipeline returns a dict of raw audio
        # plus sampling rate rather than an object with .save(); this code assumes a
        # pipeline whose output exposes .save() and writes WAV data.
        audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
        audio = audio_generator(validated_body.input_text)[0]

        audio_byte_arr = BytesIO()
        audio.save(audio_byte_arr)
        audio_byte_arr.seek(0)

        return StreamingResponse(audio_byte_arr, media_type="audio/wav")

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/generate-video")
async def generate_video(request: GenerateRequest):
    try:
        validated_body = request
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # NOTE: transformers' pipeline() has no "text-to-video" task either; this assumes
        # a compatible text-to-video pipeline whose output exposes .save().
        video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
        video = video_generator(validated_body.input_text)[0]

        video_byte_arr = BytesIO()
        video.save(video_byte_arr)
        video_byte_arr.seek(0)

        return StreamingResponse(video_byte_arr, media_type="video/mp4")

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
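
For reference, a minimal client sketch for the streaming /generate endpoint could look like the following. It assumes the server above is running locally on port 7860 (as in the __main__ block) and uses "gpt2" purely as a placeholder model name; substitute whatever causal LM is actually available.

import requests

payload = {
    "model_name": "gpt2",              # placeholder; any loadable causal LM id works
    "input_text": "Once upon a time",
    "task_type": "text-to-text",
    "max_new_tokens": 50,
    "stream": True,
}

# Stream the plain-text response chunk by chunk as the server yields tokens.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)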