Faizal2805 commited on
Commit
958c61e
·
verified ·
1 Parent(s): c608035

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -20
app.py CHANGED
@@ -1,36 +1,29 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
  import torch
5
 
6
- # Load Model from Hugging Face Model Hub
7
- MODEL_NAME = "your_username/aws-bot-model" # Replace with your actual Hugging Face model path
8
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
10
 
11
- def generate_response(user_input: str):
12
- inputs = tokenizer(user_input, return_tensors="pt", truncation=True, max_length=128)
13
- outputs = model.generate(**inputs, max_length=128)
14
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
- return response
16
 
17
- # FastAPI App
18
  app = FastAPI()
19
 
20
  class Query(BaseModel):
21
- message: str
22
 
23
  @app.post("/chat")
24
  def chat(query: Query):
25
- user_input = query.message.lower()
26
-
27
- # Placeholder for external API handling (Uncomment when you have API key)
28
- # if is_unrelated_query(user_input):
29
- # return {"response": handle_external_query(user_input)}
30
-
31
- bot_response = generate_response(user_input)
32
- return {"response": bot_response}
33
 
34
  @app.get("/")
35
  def root():
36
- return {"message": "AWS Bot API is running!"}
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
 
6
+ # Load model and tokenizer
7
+ MODEL_NAME = "meta-llama/Llama-3.2-1B" # Replace with your model
8
+
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
11
 
12
+ def generate_response(prompt: str):
13
+ inputs = tokenizer(prompt, return_tensors="pt")
14
+ outputs = model.generate(**inputs, max_length=200)
15
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
16
 
 
17
  app = FastAPI()
18
 
19
  class Query(BaseModel):
20
+ text: str
21
 
22
  @app.post("/chat")
23
  def chat(query: Query):
24
+ response = generate_response(query.text)
25
+ return {"response": response}
 
 
 
 
 
 
26
 
27
  @app.get("/")
28
  def root():
29
+ return {"message": "AWS-Bot is running!"}