Spaces:
Runtime error
Runtime error
File size: 4,399 Bytes
002fca8 2cd7197 9e3ea07 c53513a d707be1 2589dc0 dadb627 0099d95 c5f58d3 f0feabf d707be1 c550535 b916cdf c53513a 2cd7197 9e3ea07 0a9fba8 a81da59 0a9fba8 56eafc1 8d9d6f6 31c3ad2 56eafc1 0bb76a8 6159237 c550535 a81da59 0a9fba8 a81da59 c550535 a81da59 33e9df8 c53513a 609a4fb ca7a52b c550535 0bb76a8 c550535 31c3ad2 c550535 31c3ad2 c550535 31c3ad2 c550535 c53513a 0bb76a8 6159237 0bb76a8 6159237 c550535 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware # Importa il middleware CORS
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from datetime import datetime
from gradio_client import Client
import base64
import requests
import os
import socket
import time
# --------------------------------------------------- FastAPI server setup ------------------------------------------------------
# Hugging Face inference client used by `generate` for all text generation.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

app = FastAPI()

# Fully permissive CORS: any origin, method and header is accepted, and
# credentialed requests are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class InputData(BaseModel):
    """Request body for ``POST /Genera`` (text generation)."""

    # Prompt text sent to the language model.
    input: str
    # Sampling temperature; `generate` clamps values below 0.01 up to 0.01.
    temperature: float = 0.2
    # Upper bound on the number of tokens the model may generate.
    max_new_tokens: int = 30000
    # Nucleus-sampling probability mass.
    top_p: float = 0.95
    # Penalty applied to repeated tokens (1.0 presumably means "no penalty").
    repetition_penalty: float = 1.0
class InputImage(BaseModel):
    """Request body for ``POST /Immagine`` (image generation).

    The fields are forwarded positionally to the remote Stable Diffusion XL
    Space, so their meaning is defined by that Space's ``fn_index=0`` function.
    """

    # Positive text prompt describing the image.
    input: str
    # Negative prompt (things the image should avoid); empty by default.
    negativePrompt: str = ''
    # Number of diffusion steps.
    steps: int = 30
    # Presumably the classifier-free guidance scale — confirm against the Space.
    cfg: int = 7
    # Generation seed; presumably -1 asks the Space for a random seed — confirm.
    seed: int = -1
class PostSpazio(BaseModel):
    """Request body for ``POST /PostSpazio`` (generic proxy to a Gradio Space)."""

    # Space identifier or URL handed to `gradio_client.Client`.
    nomeSpazio: str
    # Single input value sent to the Space's endpoint.
    input: str = ''
    # Name of the Space API endpoint to call.
    api_name: str = "/chat"
# --------------------------------------------------- Text generation ------------------------------------------------------
@app.post("/Genera")
def read_root(request: Request, input_data: InputData):
    """Generate a single-turn completion for ``input_data.input``.

    The request's sampling parameters are forwarded unchanged to `generate`
    together with an empty conversation history; the model output is returned
    under the ``"response"`` key.
    """
    completion = generate(
        input_data.input,
        [],  # no prior conversation turns
        input_data.temperature,
        input_data.max_new_tokens,
        input_data.top_p,
        input_data.repetition_penalty,
    )
    return {"response": completion}
def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
    """Run one non-streaming text generation against the module-level `client`.

    Args:
        prompt: the new user message.
        history: iterable of ``(user_message, bot_response)`` pairs, rendered
            by `format_prompt` ahead of the new message.
        temperature: sampling temperature; converted to float and raised to a
            floor of 0.01 (presumably because the backend requires a strictly
            positive temperature when sampling — confirm).
        max_new_tokens: generation length cap, passed through unchanged.
        top_p: nucleus-sampling mass, converted to float.
        repetition_penalty: passed through unchanged.

    Returns:
        The result of ``client.text_generation`` with ``stream=False`` and
        ``details=False``, i.e. the generated text.
    """
    # Sampling is always on, with a fixed seed of 42.
    sampling = {
        "temperature": max(float(temperature), 1e-2),
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": 42,
    }
    full_prompt = format_prompt(prompt, history)
    return client.text_generation(full_prompt, stream=False, details=False, **sampling)
def format_prompt(message, history):
    """Render a conversation in the Mixtral ``[INST] ... [/INST]`` chat format.

    Each ``(user, bot)`` pair in *history* becomes ``"[INST] user [/INST] bot</s> "``.
    The new *message* is preceded by a bracketed local timestamp with
    microseconds (``[YYYY-mm-dd HH:MM:SS.ffffff]``). Presumably the timestamp
    makes every prompt unique, so that the fixed sampling seed and any
    inference-side cache don't return identical answers — confirm before
    removing it.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> " for user_turn, bot_turn in history
    )
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    return f"<s>{past_turns}[{stamp}] [INST] {message} [/INST]"
# --------------------------------------------------- Image generation ------------------------------------------------------
@app.post("/Immagine")
def generate_image(request: Request, input_data: InputImage):
    """Generate a 1024x1024 image through a remote SDXL Gradio Space.

    Calls the Space's ``fn_index=0`` function up to 20 times. There is a
    one-second pause between attempts, and only
    ``requests.exceptions.HTTPError`` triggers a retry. The Space returns a
    local file path. That file is read and returned base64-encoded.

    Returns:
        ``{"response": <base64 image>}`` on success, or
        ``{"error": "Errore interno del server persistente"}`` when every
        attempt failed with an HTTP error.
    """
    # NOTE(review): a `--replicas/<id>/` URL appears to pin one specific replica
    # of the Space. Such URLs presumably stop working once the Space restarts.
    # Consider using the Space id instead.
    client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/545b5tw7n/")
    max_attempts = 20
    for attempt in range(max_attempts):
        try:
            image_path = client.predict(
                input_data.input,
                input_data.negativePrompt,
                input_data.steps,
                input_data.cfg,
                1024,  # presumably width — confirm against the Space's signature
                1024,  # presumably height
                input_data.seed,
                fn_index=0,
            )
        except requests.exceptions.HTTPError:
            # Pause before the next try, but don't waste a second after the
            # final attempt (the original slept and then gave up anyway).
            if attempt < max_attempts - 1:
                time.sleep(1)
            continue
        # Only the remote call sits inside `try`: file I/O cannot raise HTTPError.
        with open(image_path, "rb") as img_file:
            img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
        return {"response": img_base64}
    # Every attempt failed with an HTTP error. This is the single exit for that
    # case; the old trailing "Numero massimo di tentativi raggiunto" return was
    # unreachable.
    return {"error": "Errore interno del server persistente"}
# --------------------------------------------------- PostSpazio API ------------------------------------------------------
# The handler used to be called `generate_image`, silently rebinding the
# module-level name of the /Immagine handler. `name="generate_image"` keeps
# the route name, and therefore the generated OpenAPI operationId and summary,
# exactly as before, while the Python function gets a distinct name.
@app.post("/PostSpazio", name="generate_image")
def post_spazio(request: Request, input_data: PostSpazio):
    """Forward one input to a caller-chosen Gradio Space and relay its result.

    ``input_data.nomeSpazio`` is handed to ``gradio_client.Client``.
    ``input_data.input`` is then sent as the sole positional argument to the
    endpoint named by ``input_data.api_name`` (``"/chat"`` by default). The raw
    prediction result is returned under the ``"response"`` key. Errors raised
    by the client propagate unhandled.
    """
    # NOTE(review): `nomeSpazio` comes straight from the request body. The
    # server will therefore connect to whatever Space/URL the caller names
    # (SSRF risk). Consider an allow-list if this API is publicly reachable.
    space_client = Client(input_data.nomeSpazio)
    result = space_client.predict(input_data.input, api_name=input_data.api_name)
    return {"response": result}
@app.get("/")
def read_general():
    """Root endpoint: return a welcome message pointing to the API docs."""
    welcome = "Benvenuto. Per maggiori info: https://matteoscript-fastapi.hf.space/docs"
    return {"response": welcome}
|