from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware  # Import the CORS middleware
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from datetime import datetime
from gradio_client import Client
import base64
import requests
import time

#--------------------------------------------------- FastAPI server definition ------------------------------------------------------
app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

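# Allow cross-origin requests from any origin, method, and header, so browser clients hosted elsewhere can call this API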
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class InputData(BaseModel):
    input: str
    temperature: float = 0.2
    max_new_tokens: int = 30000
    top_p: float = 0.95
    repetition_penalty: float = 1.0

class InputImage(BaseModel):
    input: str
    negativePrompt: str = '' 
    steps: int = 25
    cfg: int = 7
    seed: int = -1

#--------------------------------------------------- TEXT generation ------------------------------------------------------
@app.post("/Genera")
def generate_text(request: Request, input_data: InputData):
    input_text = input_data.input
    temperature = input_data.temperature
    max_new_tokens = input_data.max_new_tokens
    top_p = input_data.top_p
    repetition_penalty = input_data.repetition_penalty
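    # Each request starts from an empty history, so the endpoint is stateless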
    history = [] 
    generated_response = generate(input_text, history, temperature, max_new_tokens, top_p, repetition_penalty)
    return {"response": generated_response}

def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
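    # Sampling settings; the seed is fixed at 42, so output is repeatable for a given prompt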
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(prompt, history)
    output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
    return output
    
def format_prompt(message, history):
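    # Build the Mixtral-Instruct chat format, e.g. "<s>[INST] q [/INST] a</s> [timestamp] [INST] message [/INST]"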
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    prompt += f"[{now}] [INST] {message} [/INST]"
    return prompt
    
#--------------------------------------------------- IMAGE generation ------------------------------------------------------
@app.post("/Immagine")
def generate_image(request: Request, input_data: InputImage):
    image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545b5tw7n/")
    max_attempts = 20
    attempt = 0
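    # Retry the Space call up to max_attempts times, sleeping 1 second between failed attempts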
    while attempt < max_attempts:
        try:
            result = image_client.predict(
                input_data.input,
                input_data.negativePrompt,
                input_data.steps, 
                input_data.cfg,
                1024,
                1024,
                input_data.seed,
                fn_index=0
            )
            image_path = result
            with open(image_path, 'rb') as img_file:
                img_binary = img_file.read()
                img_base64 = base64.b64encode(img_binary).decode('utf-8')
            return {"response": img_base64}
        except requests.exceptions.HTTPError as e:
            time.sleep(1)
            attempt += 1
            if attempt < max_attempts:
                continue
            else:
                return {"error": "Errore interno del server persistente"}
    return {"error": "Numero massimo di tentativi raggiunto"}

@app.get("/")
def read_general():
    return {"response": "Benvenuto. Per maggiori info: https://matteoscript-fastapi.hf.space/docs"}