from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware  # Import the CORS middleware
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from datetime import datetime
from gradio_client import Client
import base64

app = FastAPI()
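# Shared Hugging Face Inference client for the Mixtral instruct model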
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
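# Wide-open CORS: any origin, method, and header is allowed; tighten for production.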

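# Request body for the /Genera text endpoint; defaults mirror generate() below.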
class InputData(BaseModel):
    input: str
    temperature: float = 0.2
    max_new_tokens: int = 30000
    top_p: float = 0.95
    repetition_penalty: float = 1.0

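# Request body for /Immagine; the default seed is fixed, so identical requests reproduce the same image.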
class InputImage(BaseModel):
    input: str
    negativePrompt: str = ''
    steps: int = 25
    cfg: int = 5
    seed: int = 453666937

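# Build the Mixtral instruct prompt: each past turn becomes
# "[INST] user [/INST] response</s>", and the new message is appended with a
# timestamp prefix, e.g.
# "<s>[INST] Hi [/INST] Hello</s> [2024-01-01 12:00:00.000000] [INST] How are you? [/INST]"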
def format_prompt(message, history):
    prompt = "<s>"
    #with open('Manuale.txt', 'r') as file:
    #    manual_content = file.read()
    #    prompt += f"Read this manual; afterwards I will ask you some questions: {manual_content}"
    
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    prompt += f"[{now}] [INST] {message} [/INST]"

    return prompt

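# POST /Genera: run text generation for the supplied input and sampling parameters.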
@app.post("/Genera")
def read_root(request: Request, input_data: InputData):
    input_text = input_data.input
    temperature = input_data.temperature
    max_new_tokens = input_data.max_new_tokens
    top_p = input_data.top_p
    repetition_penalty = input_data.repetition_penalty

    history = []  # The conversation history could be populated here if needed
    generated_response = generate(input_text, history, temperature, max_new_tokens, top_p, repetition_penalty)
    return {"response": generated_response}

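# POST /Immagine: generate an image through a hosted SDXL Gradio Space and return it base64-encoded.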
@app.post("/Immagine")
def generate_image(request: Request, input_data: InputImage):
    client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545b5tw7n/")
    result = client.predict(
        input_data.input,
        input_data.negativePrompt,
        input_data.steps, 
        input_data.cfg,
        1024,
        1024,
        input_data.seed,
        fn_index=0
    )
    # The Gradio client returns a local file path to the generated image
    image_path = result
    with open(image_path, 'rb') as img_file:
        img_binary = img_file.read()
        img_base64 = base64.b64encode(img_binary).decode('utf-8')
    return {"response": img_base64}

@app.get("/")
def read_general():
    return {"response": "Benvenuto. Per maggiori info vai a /docs"}  # Restituisci la risposta generata come JSON

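# Run generation through the shared InferenceClient with the given sampling parameters.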
def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # clamp to a small positive value; sampling requires temperature > 0
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: identical requests reproduce the same output
    )
    formatted_prompt = format_prompt(prompt, history)
    output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
    return output

    # Alternative streaming variant (unreachable after the return above, kept for reference):
    #stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=False, return_full_text=False)
    # Accumulate the output in a list
    #output_list = []
    #for response in stream:
    #    output_list.append(response.token.text)
    #return iter(output_list)  # Return the list as an iterator
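
# Minimal local entry point, a sketch assuming uvicorn is installed; host and
# port are illustrative defaults, not taken from the original deployment.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)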