from transformers import pipeline
from fastapi import FastAPI
from pydantic import BaseModel

# Load the prompt text from /prompt.txt at startup.
# Currently not referenced by the endpoints below.
prompt_path = "/prompt.txt"
with open(prompt_path, "r", encoding="utf-8") as file:
    prompt = file.read()


def post_process(essay):
    # Find the index of the first occurrence of the word "Feedback:"
    feedback_index = essay.find("Feedback:")

    # If "Feedback:" is not found, return the original essay
    if feedback_index == -1:
        return essay

    # Find the index of the newline after the first occurrence of "Feedback:"
    newline_index = essay.find("\n", feedback_index)

    # If newline is not found, return the original essay
    if newline_index == -1:
        return essay

    # Return the essay up to the newline after the first occurrence of "Feedback:"
    return essay[:newline_index]
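# Example with hypothetical model output: everything after the "Feedback:" line
# is dropped, so
#   post_process("Score: 4\nFeedback: Clear thesis.\nExtra text")
# returns "Score: 4\nFeedback: Clear thesis.".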

def pre_process(instruction, essay):
    # Combine the task instruction and the essay into a single prompt body,
    # separated by a newline.
    return f"{instruction}\n{essay}"
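# Example with hypothetical inputs:
#   pre_process("Discuss remote work.", "Remote work saves time.")
# returns "Discuss remote work.\nRemote work saves time.", which /score then
# wraps in Mistral's instruction tags as
#   "<s>[INST] Discuss remote work.\nRemote work saves time. [/INST]"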


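# Load the fine-tuned Mistral model once at startup; device_map="auto" lets
# Accelerate place the weights on the available GPU(s) and CPU.
pipe = pipeline(
    "text-generation",
    model="gildead/mistral-aes-414",
    device_map="auto",
)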


# Request body for /score: the essay text and the task instruction it answers.
class Message(BaseModel):
    essay: str
    instruction: str


app = FastAPI()


@app.get("/")
async def root():
    return {"message": "Mistral API is running."}

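@app.post("/score")
def overall(message: Message):
    # Declared with plain `def` (not `async def`) so FastAPI runs this blocking
    # model call in its threadpool instead of stalling the event loop.
    text = pre_process(message.instruction, message.essay)
    result = pipe(
        f"<s>[INST] {text} [/INST]",
        max_new_tokens=200,
        num_return_sequences=1,
    )

    # By default the pipeline returns the prompt followed by the completion;
    # keep only the text after the closing [/INST] tag.
    generated_text = result[0]["generated_text"]
    output = generated_text.split("[/INST]", 1)[-1].strip()
    final_output = post_process(output)

    return {"result": final_output}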
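# Example client call (a sketch, not part of the service). It assumes the app is
# served locally on port 8000, e.g. `uvicorn app:app --port 8000` where the
# module name `app` is hypothetical, and that the `requests` package is installed.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/score",
#       json={"instruction": "Discuss remote work.", "essay": "Remote work saves time."},
#   )
#   print(resp.json()["result"])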