mistral-aes / app.py
Kid Omar Costelo
Add FastAPI endpoints for scoring essays and generating prompts
0276e44
raw
history blame
1.65 kB
import asyncio
import threading

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
# Getting the prompt from the prompt.txt file.
# NOTE(review): "/prompt.txt" is an absolute path at the filesystem root, and
# `prompt` is not referenced by any code visible in this file -- confirm both
# are intended (a path relative to this file may have been meant).
prompt_dir = "/prompt.txt"
# Decode explicitly as UTF-8 instead of relying on the platform's locale
# encoding, so the text is read identically on every host.
with open(prompt_dir, encoding="utf-8") as file:
    prompt = file.read()
def post_process(essay):
    """Cut generated text off at the end of the first "Feedback:" line.

    Everything from the first newline that follows the first "Feedback:"
    marker onward is discarded. When the marker is missing, or no newline
    follows it, the text comes back unchanged.
    """
    before, marker, after = essay.partition("Feedback:")
    if not marker:
        # No "Feedback:" marker anywhere: nothing to trim.
        return essay
    # Keep the rest of the feedback line; drop the newline and what follows.
    feedback_line, _newline, _rest = after.partition("\n")
    return before + marker + feedback_line
def pre_process(instruction, essay):
    """Combine the task instruction and the essay into one model input.

    The instruction comes first, separated from the essay by a single newline.
    """
    return "{}\n{}".format(instruction, essay)
# Text-generation pipeline used by the /score endpoint. It is built once when
# the module is imported and shared by every request. device_map="auto" leaves
# device placement of the model to transformers.
pipe = pipeline(
    task="text-generation",
    model="gildead/mistral-aes-414",
    device_map="auto",
)
class Message(BaseModel):
    """Request body accepted by the ``/score`` endpoint."""

    # The essay text to be scored.
    essay: str
    # The task instruction; pre_process places it before the essay in the prompt.
    instruction: str
app = FastAPI()


# Liveness check: GET / answers with a fixed status message.
@app.get("/")
async def root():
    status = {"message": "Mistral API is running."}
    return status
# Serializes model calls. Generation now runs in worker threads (see
# `overall`), and this lock keeps the previous one-generation-at-a-time
# behaviour instead of letting concurrent requests drive the model in parallel.
_generation_lock = threading.Lock()


def _generate_feedback(text):
    """Run the model on pre-processed ``text`` and return its stripped reply.

    Blocking; meant to be called from a worker thread, not the event loop.
    """
    with _generation_lock:
        result = pipe(
            f"<s>[INST] {text} [/INST]",
            max_new_tokens=200,
            num_return_sequences=1,
            # Return only the newly generated text. Recovering the reply by
            # splitting the full text on "[/INST]" breaks when the user's
            # essay or instruction itself contains that marker.
            return_full_text=False,
        )
    return result[0]["generated_text"].strip()


@app.post("/score")
async def overall(message: Message):
    """Score an essay and return the model's feedback.

    Builds the model input from ``message.instruction`` and ``message.essay``,
    runs generation off the event loop, trims the reply with ``post_process``
    and returns it as ``{"result": <text>}``.
    """
    text = pre_process(message.instruction, message.essay)
    # The pipeline call is blocking; running it directly inside this coroutine
    # would stall the event loop (and every other endpoint) for the whole
    # generation, so hand it to the default thread pool instead.
    loop = asyncio.get_running_loop()
    output = await loop.run_in_executor(None, _generate_feedback, text)
    return {"result": post_process(output)}