raccoote commited on
Commit
feec422
·
verified ·
1 Parent(s): 3605c25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -14
app.py CHANGED
@@ -1,21 +1,48 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import torch
 
3
 
4
- model_name = "raccoote/angry-birds-v1"
 
5
 
6
- # Load model and tokenizer
7
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
 
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
 
10
- # Ensure to use the model in evaluation mode to save memory
11
- model.eval()
 
 
 
 
12
 
13
- def generate_text(prompt):
14
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
15
- with torch.no_grad(): # Disable gradient calculation for inference
16
- outputs = model.generate(**inputs)
17
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Example usage
20
- response = generate_text("Hello, world!")
21
- print(response)
 
 
 
 
 
 
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
+ # Load the model and tokenizer
5
+ model_name = "raccoote/angry-birds-v1"
6
 
7
+ # Use half-precision if running on GPU
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # Load the tokenizer
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
 
13
+ # Load the model with half-precision and low memory usage options
14
+ model = AutoModelForCausalLM.from_pretrained(
15
+ model_name,
16
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
17
+ low_cpu_mem_usage=True
18
+ ).to(device)
19
 
20
+ # Function to generate responses
21
+ def generate_response(prompt):
22
+ # Tokenize input
23
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
24
+
25
+ # Generate output (inference mode with no gradient computation to save memory)
26
+ with torch.no_grad():
27
+ outputs = model.generate(
28
+ inputs["input_ids"],
29
+ max_length=150, # You can adjust the max length based on your needs
30
+ num_return_sequences=1,
31
+ do_sample=True, # Enable sampling to generate more varied responses
32
+ top_k=50, # Limits the sampled tokens to the top k choices to avoid unlikely words
33
+ top_p=0.95, # Nucleus sampling; keeps the cumulative probability of top tokens below a threshold
34
+ )
35
+
36
+ # Decode and return the response
37
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
38
+ return response
39
 
40
+ # Simple loop to interact with the chatbot
41
+ if __name__ == "__main__":
42
+ print("Chatbot is ready! Type your message below (type 'exit' to quit):")
43
+ while True:
44
+ user_input = input("You: ")
45
+ if user_input.lower() == "exit":
46
+ break
47
+ response = generate_response(user_input)
48
+ print(f"Bot: {response}")