Jonathanmann commited on
Commit
c525ec1
·
verified ·
1 Parent(s): ff68e1c

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +54 -0
handler.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import runpod
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+ from peft import AutoPeftModelForCausalLM
5
+
6
# Persona / system instructions for the fine-tuned model. The text itself ends
# with "Respond to the following ...", i.e. it is written to precede user input.
# NOTE(review): this constant is not referenced anywhere in this chunk of the
# file as written — confirm where it is (or should be) applied to prompts.
SYSTEM_PROMPT = """You are Young Jonathan Mann. You are an open hearted and anxious student at Bennington College,
studying music and recording. You are also hyper-sexual and love to play video games.
You are 20 years old. You love to write songs. Respond to the following as Young Jonathan Mann. """
10
+
11
def load_model():
    """Load the LoRA fine-tuned causal LM and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` pair: the tokenizer from the base Qwen
        repo and the PEFT model loaded from the adapter checkpoint.
    """
    base_repo = "Qwen/Qwen2.5-3B-Instruct"
    adapter_repo = "Jonathanmann/qwen-sms-600"

    # Tokenizer comes from the base model; reuse EOS as the pad token.
    tok = AutoTokenizer.from_pretrained(base_repo, trust_remote_code=True)
    tok.pad_token = tok.eos_token

    # NOTE(review): trust_remote_code executes repo-provided code from the
    # Hub — acceptable only because these repos are known/pinned by name.
    peft_model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_repo,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    return peft_model, tok
27
+
28
# Load the model once at module import; handler() reuses these globals on
# every invocation so the load cost is paid only at worker startup.
model, tokenizer = load_model()
30
+
31
def handler(event):
    """RunPod serverless handler: generate a persona reply for a prompt.

    Expected event shape: ``{"input": {"prompt": str, "max_length": int?}}``
    where ``max_length`` caps the number of newly generated tokens.

    Returns:
        ``{"response": str}`` on success, ``{"error": str}`` on failure.
    """
    try:
        prompt = event["input"]["prompt"]
        max_length = event["input"].get("max_length", 100)  # default: 100 new tokens

        # Fix: SYSTEM_PROMPT was defined but never used; its text ("Respond
        # to the following ...") shows it is meant to precede the user prompt.
        full_prompt = SYSTEM_PROMPT + prompt

        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                # Fix: without do_sample=True, generate() uses greedy decoding
                # and the temperature setting is ignored (transformers warns).
                do_sample=True,
                temperature=0.7,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Fix: decode only the newly generated tokens; decoding outputs[0]
        # whole would echo the prompt back to the caller.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        return {"response": response}
    except Exception as e:
        # Serverless boundary: surface the error to the caller rather than
        # crashing the worker.
        return {"error": str(e)}
53
+
54
+ runpod.serverless.start({"handler": handler})