# codeGenerator / main.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient

app = FastAPI()

# Define the API URL for the model
API_URL = "https://api-inference.huggingface.co/models/meta-llama/CodeLlama-13b-hf"

# Define the request model
class GenerateRequest(BaseModel):
    prompt: str
    instructions: str
    api_key: str
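
# Example request body (illustrative values only; api_key is the caller's
# Hugging Face access token, forwarded to the Inference API):
#   {
#       "prompt": "Write a bubble sort in Python.",
#       "instructions": "Respond with code only.",
#       "api_key": "hf_..."
#   }
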
def format_prompt(message, custom_instructions=None):
    """Format the prompt, prepending optional custom instructions."""
    prompt = ""
    if custom_instructions:
        prompt += f"{custom_instructions}\n"
    prompt += message
    return prompt
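
# Quick sanity check for format_prompt (expected result shown as a comment):
#   format_prompt("print('hi')", "Answer in Python.")
#   -> "Answer in Python.\nprint('hi')"
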
def CodeLlama(prompt, instructions, api, temperature=0.1, max_new_tokens=2000, top_p=0.95, repetition_penalty=1.0):
    """Generate text using the CodeLlama model with the specified sampling parameters."""
    try:
        # Clamp temperature to the minimum value the API accepts
        temperature = max(float(temperature), 1e-2)
        top_p = float(top_p)
        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=69,
        )
        # Format the prompt with the caller-supplied instructions
        formatted_prompt = format_prompt(prompt, instructions)
        # Initialize the client with the caller's token and make the request
        headers = {"Authorization": f"Bearer {api}"}
        client = InferenceClient(API_URL, headers=headers)
        response = client.text_generation(formatted_prompt, **generate_kwargs)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
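
# Illustrative direct call (hypothetical token value; the sampling overrides
# are optional and default to the values in the signature above):
#   CodeLlama("Write FizzBuzz.", "Respond in Python.", "hf_...", temperature=0.2)
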
@app.post("/generate-text")
async def generate_text(request: GenerateRequest):
    """FastAPI endpoint to generate text based on the input request."""
    if not request.prompt or not request.instructions or not request.api_key:
        raise HTTPException(status_code=400, detail="Missing required fields")
    try:
        response = CodeLlama(request.prompt, request.instructions, request.api_key)
        return {"response": response}
    except HTTPException:
        # Propagate HTTPExceptions (e.g. the 500 raised inside CodeLlama)
        # unchanged instead of wrapping them in a second 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
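
# Minimal local entry point, assuming uvicorn is installed; port 7860 is the
# usual Hugging Face Spaces convention, but any free port works.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call (hypothetical host/port and placeholder values):
#   curl -X POST http://localhost:7860/generate-text \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "...", "instructions": "...", "api_key": "..."}'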