"""FastAPI service exposing a small causal language model over HTTP.

The model and tokenizer are loaded once at import time; ``GET /infer_t5``
runs a single-turn generation, and everything else is served from the
``static`` directory.
"""

import logging

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)

app = FastAPI()

# Checkpoints tried earlier: "google/flan-t5-small" (through a
# text2text-generation pipeline) and "jingyaogong/minimind-v1-small".
MODEL = "tclh123/minimind-v1-small"

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL, trust_remote_code=True)
model = model.to(device)
model = model.eval()

# U+FFFD is what the tokenizer's decode() output ends with when the text
# cannot be fully decoded yet; such snapshots are skipped until a later one
# decodes cleanly.
_REPLACEMENT_CHAR = "\ufffd"


def _next_chunk(stream):
    """Return the next item from *stream*, or ``None`` once it is finished.

    A failure raised while generating is logged and treated as end of
    stream, so the caller can still return the text produced so far.
    """
    try:
        return next(stream)
    except StopIteration:
        return None
    except Exception:
        logger.exception("Generation stopped early; returning partial answer")
        return None


def query(
    message: str,
    max_seq_len: int = 512,
    temperature: float = 0.7,
    top_k: int = 16,
) -> str:
    """Generate a reply to *message* with the module-level model.

    Args:
        message: User text; it is prefixed with '请问,' before templating.
        max_seq_len: Used twice: the rendered prompt string is cut to its
            last ``max_seq_len - 1`` characters, and the same value is passed
            to the model as ``max_new_tokens``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.

    Returns:
        The generated text followed by a newline, or ``""`` when the model
        yields nothing at all.
    """
    prompt = '请问,' + message
    messages = [{"role": "user", "content": prompt}]

    # The slice is applied to the rendered template string (tokenize=False),
    # keeping its tail.
    new_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )[-(max_seq_len - 1):]

    input_ids = tokenizer(new_prompt).data['input_ids']
    # [None, ...] adds a leading dimension in front of the token ids.
    x = torch.tensor(input_ids, dtype=torch.long, device=device)[None, ...]

    res_y = model.generate(
        x,
        tokenizer.eos_token_id,
        max_new_tokens=max_seq_len,
        temperature=temperature,
        top_k=top_k,
        stream=True,
    )

    # Only a clean end-of-stream on the very first item means "no answer";
    # any other error at this point propagates to the caller.
    try:
        y = next(res_y)
    except StopIteration:
        return ""

    pieces = []
    history_idx = 0
    while y is not None:
        # Each decoded snapshot is sliced from history_idx, so only the text
        # not yet collected is appended.
        answer = tokenizer.decode(y[0].tolist())
        if answer and answer[-1] != _REPLACEMENT_CHAR:
            pieces.append(answer[history_idx:])
            history_idx = len(answer)
        y = _next_chunk(res_y)

    pieces.append('\n')
    return ''.join(pieces)


@app.get("/infer_t5")
def t5(input: str):
    """Run :func:`query` on the ``input`` query parameter.

    The parameter name shadows a builtin but is part of the public query
    string (``/infer_t5?input=...``), so it must not be renamed.

    Returns:
        ``{"output": <generated text>}``.
    """
    output = query(input)
    return {"output": output}


app.mount("/", StaticFiles(directory="static", html=True), name="static")


# NOTE(review): this route is registered after the catch-all mount at "/",
# so requests for "/" are presumably answered by the StaticFiles mount
# (html=True) and never reach this handler -- confirm before relying on it.
@app.get("/")
def index() -> FileResponse:
    """Return the static index page as an HTML file response."""
    return FileResponse(path="/app/static/index.html", media_type="text/html")