from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware  # Import the CORS middleware
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from datetime import datetime
from gradio_client import Client
import base64
import requests
import os
import socket
import time
#--------------------------------------------------- FastAPI server definition ------------------------------------------------------
app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class InputData(BaseModel):
    input: str
    temperature: float = 0.2
    max_new_tokens: int = 30000
    top_p: float = 0.95
    repetition_penalty: float = 1.0

class InputImage(BaseModel):
    input: str
    negativePrompt: str = ''
    steps: int = 25
    cfg: int = 5
    seed: int = 453666937
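# Illustrative request bodies for the two models above (values are made up;
# any field with a default can be omitted):
#   POST /Genera   -> {"input": "Write a short poem", "temperature": 0.2, "max_new_tokens": 512}
#   POST /Immagine -> {"input": "a watercolor fox", "negativePrompt": "blurry", "steps": 25, "cfg": 5, "seed": 453666937}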
#--------------------------------------------------- TEXT generation ------------------------------------------------------
@app.post("/Genera")
def read_root(request: Request, input_data: InputData):
    input_text = input_data.input
    temperature = input_data.temperature
    max_new_tokens = input_data.max_new_tokens
    top_p = input_data.top_p
    repetition_penalty = input_data.repetition_penalty
    history = []
    generated_response = generate(input_text, history, temperature, max_new_tokens, top_p, repetition_penalty)
    return {"response": generated_response}
def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(prompt, history)
    output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
    return output
def format_prompt(message, history):
prompt = "<s>"
for user_prompt, bot_response in history:
prompt += f"[INST] {user_prompt} [/INST]"
prompt += f" {bot_response}</s> "
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
prompt += f"[{now}] [INST] {message} [/INST]"
return prompt
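# Worked example: with an empty history and message "Hello", format_prompt
# returns a string shaped like (timestamp abbreviated)
#   "<s>[2024-01-01 12:00:00.000000] [INST] Hello [/INST]"
# and with one (user, bot) pair in history it becomes
#   "<s>[INST] previous question [/INST] previous answer</s> [...] [INST] Hello [/INST]"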
#--------------------------------------------------- IMAGE generation ------------------------------------------------------
@app.post("/Immagine")
def generate_image(request: Request, input_data: InputImage):
    # Gradio client pointed at a public SDXL Space used as the image backend
    client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545b5tw7n/")
    max_attempts = 10
    attempt = 0
    while attempt < max_attempts:
        try:
            result = client.predict(
                input_data.input,
                input_data.negativePrompt,
                input_data.steps,
                input_data.cfg,
                1024,
                1024,
                input_data.seed,
                fn_index=0
            )
            # predict returns the local path of the generated image file
            image_path = result
            with open(image_path, 'rb') as img_file:
                img_binary = img_file.read()
            img_base64 = base64.b64encode(img_binary).decode('utf-8')
            return {"response": img_base64}
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 500:
                # Transient server error: wait a second and retry
                time.sleep(1)
                attempt += 1
                if attempt < max_attempts:
                    continue
                else:
                    return {"error": "Persistent internal server error"}
            else:
                return {"error": "Error other than 500"}
    return {"error": "Maximum number of attempts reached"}
@app.get("/")
def read_general():
return {"response": "Benvenuto. Per maggiori info: https://matteoscript-fastapi.hf.space/docs"}