from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware  # Import the CORS middleware
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from datetime import datetime
from gradio_client import Client
import os

app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class InputData(BaseModel):
    input: str
    temperature: float = 0.2
    max_new_tokens: int = 30000
    top_p: float = 0.95
    repetition_penalty: float = 1.0

def format_prompt(message, history):
    prompt = "<s>"
    
    # Optionally prepend the manual so the model can answer questions about it:
    # with open('Manuale.txt', 'r') as file:
    #     manual_content = file.read()
    #     prompt += f"Read this manual; afterwards I will ask you some questions: {manual_content}"
    
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    prompt += f"[{now}] [INST] {message} [/INST]"

    return prompt
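
# For reference, format_prompt emits the Mixtral-Instruct chat template; the
# timestamp prefix is this app's own addition, not part of the official format.
# One history turn plus a new message renders roughly as:
#   <s>[INST] question [/INST] answer</s> [2024-01-01 12:00:00.000000] [INST] new message [/INST]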

@app.post("/Genera")
def generate_text(request: Request, input_data: InputData):
    input_text = input_data.input
    temperature = input_data.temperature
    max_new_tokens = input_data.max_new_tokens
    top_p = input_data.top_p
    repetition_penalty = input_data.repetition_penalty

    history = []  # Conversation history could be threaded through here if needed
    generated_response = generate(input_text, history, temperature, max_new_tokens, top_p, repetition_penalty)
    return {"response": generated_response}

@app.get("/Immagine")
def generate_image():
    # Note: this replica-specific Space URL is ephemeral and may stop resolving;
    # the Space's base URL is the more stable target.
    client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545b5tw7n/")
    result = client.predict(
        "a giant monster hybrid of dragon and spider, in dark dense foggy forest ",
        "",
        25,
        5,
        1024,
        1024,
        453666937,
        fn_index=0
    )
    # predict() returns a file path; use os.path.join rather than raw string
    # concatenation, which drops the separator for relative paths.
    image_url = os.path.join(os.path.dirname(os.path.abspath(__file__)), result)
    return {"response": image_url}

@app.get("/Test")
def test_prompt():  # renamed: generate_image is already defined above
    result = "a giant monster hybrid of dragon and spider, in dark dense foggy forest "
    return {"response": result}

@app.get("/")
def read_general():
    return {"response": "Benvenuto. Per maggiori info vai a /docs"}  # Restituisci la risposta generata come JSON

def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
    # The inference backend requires a strictly positive temperature, so clamp it.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(prompt, history)
    output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
    return output

    # Streaming variant, kept for reference (note: accessing response.token.text
    # requires details=True; with details=False the stream yields plain strings):
    # stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    # # Accumulate the streamed tokens in a list
    # output_list = []
    # for response in stream:
    #     output_list.append(response.token.text)
    # return iter(output_list)  # Return the list as an iterator
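
# A minimal usage sketch (assumptions: the app is served locally, e.g. with
# `uvicorn app:app --port 8000`; module name, host, and port depend on your
# deployment):
#
#   import requests
#   payload = {"input": "Hello, who are you?", "temperature": 0.2}
#   r = requests.post("http://localhost:8000/Genera", json=payload)
#   print(r.json()["response"])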