Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,8 @@ import torch
|
|
| 5 |
import optimum
|
| 6 |
from transformers import (AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, GenerationConfig, pipeline,)
|
| 7 |
|
| 8 |
-
app = FastAPI()
|
|
|
|
| 9 |
|
| 10 |
# Load the model and tokenizer
|
| 11 |
model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
|
|
@@ -51,13 +52,17 @@ def load_model_norm():
|
|
| 51 |
return model, tokenizer
|
| 52 |
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
@app.get("/")
|
| 55 |
async def read_root():
|
| 56 |
return {"message": "Welcome to Eren Bot!"}
|
| 57 |
|
| 58 |
|
| 59 |
# Endpoint to start a new conversation thread
|
| 60 |
-
@app.post('/start_conversation')
|
| 61 |
async def start_conversation(request: Request):
|
| 62 |
data = await request.json()
|
| 63 |
prompt = data.get('prompt')
|
|
@@ -73,7 +78,7 @@ async def start_conversation(request: Request):
|
|
| 73 |
|
| 74 |
|
| 75 |
# Endpoint to get the response of a conversation thread
|
| 76 |
-
@app.get('/get_response/{thread_id}')
|
| 77 |
async def get_response(thread_id: int):
|
| 78 |
if thread_id not in conversations:
|
| 79 |
raise HTTPException(status_code=404, detail="Thread not found")
|
|
|
|
| 5 |
import optimum
|
| 6 |
from transformers import (AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, GenerationConfig, pipeline,)
|
| 7 |
|
| 8 |
+
app = FastAPI(title="Deploying FastAPI Apps on Huggingface")
|
| 9 |
+
|
| 10 |
|
| 11 |
# Load the model and tokenizer
|
| 12 |
model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
|
|
|
|
| 52 |
return model, tokenizer
|
| 53 |
|
| 54 |
|
| 55 |
+
@app.get("/", tags=["Home"])
|
| 56 |
+
async def api_home():
|
| 57 |
+
return {'detail': 'Welcome to Eren Bot!'}
|
| 58 |
+
|
| 59 |
@app.get("/")
|
| 60 |
async def read_root():
|
| 61 |
return {"message": "Welcome to Eren Bot!"}
|
| 62 |
|
| 63 |
|
| 64 |
# Endpoint to start a new conversation thread
|
| 65 |
+
@app.post('/api/start_conversation')
|
| 66 |
async def start_conversation(request: Request):
|
| 67 |
data = await request.json()
|
| 68 |
prompt = data.get('prompt')
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# Endpoint to get the response of a conversation thread
|
| 81 |
+
@app.get('/api/get_response/{thread_id}')
|
| 82 |
async def get_response(thread_id: int):
|
| 83 |
if thread_id not in conversations:
|
| 84 |
raise HTTPException(status_code=404, detail="Thread not found")
|