alaa-ahmed14 committed · verified
Commit 19cd085 · 1 Parent(s): d1153c0

Update app.py

Files changed (1):
  1. app.py +28 -14
app.py CHANGED
@@ -1,10 +1,13 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import HTMLResponse
+from fastapi.staticfiles import StaticFiles
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-
 import os
 
 
+
 # Set cache directory for Hugging Face Transformers
 os.environ["TRANSFORMERS_CACHE"] = "/home/user/.cache"
 
@@ -15,15 +18,26 @@ model = AutoModelForCausalLM.from_pretrained("matsant01/STEMerald-2b")
 # Initialize FastAPI app
 app = FastAPI()
 
-
-
-@app.get("/")
-def read_root():
-    return {"message": "Welcome to the STEMerald-2b API"}
-
-#@app.post("/generate/")
-#def generate_text(prompt: str):
-#    inputs = tokenizer(prompt, return_tensors="pt")
-#    outputs = model.generate(inputs["input_ids"], max_length=50)
-#    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-#    return {"generated_text": generated_text}
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+# Serve the HTML file
+@app.get("/", response_class=HTMLResponse)
+async def read_root():
+    with open("index.html", "r") as f:
+        return f.read()
+
+@app.post("/generate/")
+async def generate_text(prompt: str):
+    if not prompt:
+        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+    inputs = tokenizer(prompt, return_tensors="pt")
+    outputs = model.generate(inputs["input_ids"], max_length=50)
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return {"generated_text": generated_text}