Hjgugugjhuhjggg committed
Commit 5ded4bc · verified · Parent: c938099

Update app.py

Files changed (1): app.py (+40 −54)
app.py CHANGED
@@ -9,7 +9,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
-    StoppingCriteriaList
+    StoppingCriteriaList,
+    StoppingCriteria
 )
 import uvicorn
 import asyncio
@@ -48,9 +49,11 @@ class GenerateRequest(BaseModel):
 
 class LocalModelLoader:
     def __init__(self):
-        pass
+        self.loaded_models = {}
 
     async def load_model_and_tokenizer(self, model_name):
+        if model_name in self.loaded_models:
+            return self.loaded_models[model_name]
         try:
             config = AutoConfig.from_pretrained(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
@@ -59,12 +62,24 @@ class LocalModelLoader:
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
 
+            self.loaded_models[model_name] = (model, tokenizer)
             return model, tokenizer
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
 
 model_loader = LocalModelLoader()
 
+class StopOnTokens(StoppingCriteria):
+    def __init__(self, stop_token_ids: list[int]):
+        self.stop_token_ids = stop_token_ids
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        for stop_id in self.stop_token_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     try:
@@ -96,69 +111,40 @@ async def generate(request: GenerateRequest):
             num_return_sequences=num_return_sequences,
         )
 
+        stop_token_ids = []
+        if stop_sequences:
+            stop_token_ids = tokenizer.convert_tokens_to_ids(stop_sequences)
+        stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None
+
+
         return StreamingResponse(
-            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
+            stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
             media_type="text/plain"
         )
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
-async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
+async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-    input_length = encoded_input["input_ids"].shape[1]
-    remaining_tokens = max_length - input_length
-
-    if remaining_tokens <= 0:
-        yield ""
-
-    generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
-
-    def stop_criteria(input_ids, scores):
-        decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
-        return decoded_output in stop_sequences
-
-    stopping_criteria = StoppingCriteriaList([stop_criteria])
-
-    output_text = ""
-    outputs = model.generate(
-        **encoded_input,
-        do_sample=generation_config.do_sample,
-        max_new_tokens=generation_config.max_new_tokens,
-        temperature=generation_config.temperature,
-        top_p=generation_config.top_p,
-        top_k=generation_config.top_k,
-        repetition_penalty=generation_config.repetition_penalty,
-        num_return_sequences=generation_config.num_return_sequences,
-        stopping_criteria=stopping_criteria,
-        output_scores=True,
-        return_dict_in_generate=True
-    )
-
-    for output in outputs.sequences:
-        for token_id in output:
-            token = tokenizer.decode(token_id, skip_special_tokens=True)
-            yield token
-            await asyncio.sleep(chunk_delay)
-
-    if stop_sequences and any(stop in output_text for stop in stop_sequences):
-        yield output_text
-        return
-
-    outputs = model.generate(
+
+    with torch.no_grad():
+        streamer = model.generate(
             **encoded_input,
-            do_sample=generation_config.do_sample,
-            max_new_tokens=generation_config.max_new_tokens,
-            temperature=generation_config.temperature,
-            top_p=generation_config.top_p,
-            top_k=generation_config.top_k,
-            repetition_penalty=generation_config.repetition_penalty,
-            num_return_sequences=generation_config.num_return_sequences,
-            stopping_criteria=stopping_criteria,
-            output_scores=True,
-            return_dict_in_generate=True
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria_list,
+            stream=True,  # Ensure streaming is enabled if supported by the model
+            return_dict_in_generate=True,
+            output_scores=True
         )
 
+    for output in streamer.sequences[:, encoded_input["input_ids"].shape[-1]:]:  # Stream from the *new* tokens
+        token = tokenizer.decode(output, skip_special_tokens=True)
+        if token:  # Avoid yielding empty tokens
+            yield token
+            await asyncio.sleep(chunk_delay)
+
+
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
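
The new stream_text keeps the comment "if supported by the model" next to stream=True; if the installed transformers release does not accept a stream argument to generate(), a common alternative is TextIteratorStreamer, where generate() runs in a worker thread and decoded text pieces are yielded as they arrive. The sketch below is illustrative only and not part of this commit; it reuses the committed parameter names (generation_config, stopping_criteria_list, chunk_delay) and assumes a transformers version that ships TextIteratorStreamer.

# Illustrative sketch only (not part of commit 5ded4bc): token-by-token streaming
# via transformers' TextIteratorStreamer rather than a stream=True flag.
import asyncio
from threading import Thread

from transformers import TextIteratorStreamer

async def stream_text(model, tokenizer, input_text, generation_config,
                      stopping_criteria_list, device, chunk_delay, max_length=2048):
    encoded_input = tokenizer(input_text, return_tensors="pt",
                              truncation=True, max_length=max_length).to(device)
    # skip_prompt=True so only newly generated tokens are emitted, not the prompt.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **encoded_input,
        generation_config=generation_config,
        stopping_criteria=stopping_criteria_list,
        streamer=streamer,
    )
    # generate() blocks until completion, so run it in a background thread
    # and consume the streamer as tokens become available.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for text_piece in streamer:
        if text_piece:  # avoid yielding empty chunks
            yield text_piece
            await asyncio.sleep(chunk_delay)
    thread.join()

One caveat carried over from the committed endpoint: tokenizer.convert_tokens_to_ids(stop_sequences) only resolves stop strings that happen to be single vocabulary tokens; multi-token stop strings would not match as intended and would need to be tokenized into full id sequences before being checked by a stopping criterion.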