mateoluksenberg committed
Commit 33432bd · verified · 1 Parent(s): 7ba24b6

Update app.py

Files changed (1): app.py +47 -0
app.py CHANGED
@@ -27,6 +27,53 @@ async def test_endpoint(message: dict):
     return response


+@app.post("/chat/")
+async def chat_endpoint(message: dict):
+    if "text" not in message:
+        raise HTTPException(status_code=400, detail="Missing 'text' in request body")
+
+    chat_message = message["text"]
+    response_text = generate_chat_response(chat_message)
+
+    return {"response": response_text}
+
+def generate_chat_response(text: str):
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+    conversation = [{"role": "user", "content": text}]
+    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
+                                              return_tensors="pt", return_dict=True).to(model.device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        max_length=4096,
+        streamer=streamer,
+        do_sample=True,
+        top_p=0.9,
+        top_k=50,
+        temperature=0.7,
+        repetition_penalty=1.0,
+        eos_token_id=[151329, 151336, 151338],
+    )
+    gen_kwargs = {**input_ids, **generate_kwargs}
+
+    with torch.no_grad():
+        thread = Thread(target=model.generate, kwargs=gen_kwargs)
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text
+
+    return buffer
+
+
 MODEL_LIST = ["nikravan/glm-4vq"]

 HF_TOKEN = os.environ.get("HF_TOKEN", None)
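
For reference, a minimal sketch of exercising the new /chat/ endpoint once the app is running. The host, port, and use of the requests library are assumptions for illustration, not part of the commit:

# Usage sketch (not part of the commit): POST to the new /chat/ endpoint.
# Assumes the app is served locally, e.g. `uvicorn app:app --port 8000`;
# the URL and port are assumptions.
import requests

resp = requests.post(
    "http://localhost:8000/chat/",
    json={"text": "Hello, what can you do?"},
)
resp.raise_for_status()
print(resp.json()["response"])  # the full buffered generation from the model

Two things worth noting about the code as committed: generate_chat_response reloads the model and tokenizer on every request, and the torch.no_grad() context in the calling thread does not cover model.generate running in the worker Thread, since PyTorch grad mode is thread-local.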
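If the per-request reload proves too slow, one possible follow-up (a sketch under assumptions, not the author's code) is to cache the model and tokenizer once and reuse them across requests; MODEL_ID is assumed here to resolve to MODEL_LIST[0]:

# Sketch only: load the model and tokenizer once and reuse them.
from functools import lru_cache

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "nikravan/glm-4vq"  # assumption: matches MODEL_LIST[0] in app.py

@lru_cache(maxsize=1)
def get_model_and_tokenizer():
    # First call downloads/loads the weights; later calls return the cache.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, tokenizer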