Elaineyy commited on
Commit
ce5eee7
·
verified ·
1 Parent(s): 216abcb

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +44 -4
server.py CHANGED
@@ -3,13 +3,19 @@ from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
 
 
6
 
7
  app = FastAPI()
8
 
9
- # Fix: Set a writable cache directory
10
  os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
 
11
 
12
- # Load DeepSeek-Coder-V2-Base Model
 
 
 
13
  model_name = "deepseek-ai/DeepSeek-Coder-V2-Base"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
@@ -17,17 +23,51 @@ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float
17
  class CodeRequest(BaseModel):
18
  user_story: str
19
 
 
 
 
20
  @app.post("/generate-code")
21
  def generate_code(request: CodeRequest):
22
- """Generates code based on user story"""
23
  prompt = f"Generate structured code for: {request.user_story}"
24
-
25
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
26
  output = model.generate(**inputs, max_length=300)
27
  generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
28
 
29
  return {"generated_code": generated_code}
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  if __name__ == "__main__":
32
  import uvicorn
33
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
3
  from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
6
+ import subprocess
7
+ import tempfile
8
 
9
  app = FastAPI()
10
 
11
+ # Fix: Set writable cache directory for Hugging Face models
12
  os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
13
+ os.environ["HF_HOME"] = "/app/cache"
14
 
15
+ # Ensure cache directory exists
16
+ os.makedirs("/app/cache", exist_ok=True)
17
+
18
+ # ✅ Load AI Model
19
  model_name = "deepseek-ai/DeepSeek-Coder-V2-Base"
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
 
23
  class CodeRequest(BaseModel):
24
  user_story: str
25
 
26
+ class TestRequest(BaseModel):
27
+ code: str
28
+
29
  @app.post("/generate-code")
30
  def generate_code(request: CodeRequest):
31
+ """Generates AI-powered structured code based on user story"""
32
  prompt = f"Generate structured code for: {request.user_story}"
33
+
34
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
35
  output = model.generate(**inputs, max_length=300)
36
  generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
37
 
38
  return {"generated_code": generated_code}
39
 
40
+ @app.post("/test-code")
41
+ def test_code(request: TestRequest):
42
+ """Runs automated testing on AI-generated code"""
43
+ try:
44
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
45
+ temp_file.write(request.code.encode())
46
+ temp_file.close()
47
+
48
+ result = subprocess.run(["pytest", temp_file.name], capture_output=True, text=True)
49
+ os.unlink(temp_file.name)
50
+
51
+ if result.returncode == 0:
52
+ return {"test_status": "All tests passed!"}
53
+ else:
54
+ return {"test_status": "Test failed!", "details": result.stderr}
55
+
56
+ except Exception as e:
57
+ raise HTTPException(status_code=500, detail=str(e))
58
+
59
+ @app.get("/execute-code")
60
+ def execute_code():
61
+ """Executes AI-generated code and returns output"""
62
+ sample_code = "print('Hello from AI-generated code!')"
63
+
64
+ try:
65
+ result = subprocess.run(["python3", "-c", sample_code], capture_output=True, text=True)
66
+ return {"status": "Execution successful!", "output": result.stdout}
67
+
68
+ except Exception as e:
69
+ return {"status": "Execution failed!", "error": str(e)}
70
+
71
  if __name__ == "__main__":
72
  import uvicorn
73
  uvicorn.run(app, host="0.0.0.0", port=7860)