Spaces:
Runtime error
Runtime error
File size: 3,556 Bytes
002fca8 2cd7197 9e3ea07 c53513a d707be1 2589dc0 dadb627 0099d95 c5f58d3 d707be1 b916cdf c53513a 2cd7197 9e3ea07 0a9fba8 a81da59 0a9fba8 56eafc1 c53513a a81da59 b8446db 68e04a7 a81da59 d707be1 a81da59 b916cdf a81da59 0a9fba8 a81da59 c53513a a81da59 33e9df8 c53513a 56eafc1 29bc3a1 56eafc1 29bc3a1 56eafc1 29bc3a1 4caf2cc 33e9df8 88cc3e6 a061413 c53513a 609a4fb ca7a52b c53513a f381f25 36c21dd ca7a52b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware # Importa il middleware CORS
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from datetime import datetime
from gradio_client import Client
import base64
import requests
import os
import socket
app = FastAPI()

# Hosted Mixtral client shared by every text-generation request (used by generate()).
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Allow cross-origin calls from anywhere, with any method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended, especially if auth/cookies are added.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class InputData(BaseModel):
    # Request body for POST /Genera; every field except `input` has a default.
    # Values are passed straight through to generate().
    input: str  # the user message that format_prompt() wraps in [INST] ... [/INST]
    temperature: float = 0.2  # generate() floors this at 0.01 before use
    # Forwarded as `max_new_tokens` to InferenceClient.text_generation.
    # NOTE(review): 30000 is presumably close to the model's context window, so long
    # prompts may be rejected by the backend — confirm the limit.
    max_new_tokens: int = 30000
    top_p: float = 0.95  # forwarded as `top_p` (cast to float in generate())
    repetition_penalty: float = 1.0  # forwarded as `repetition_penalty`
class InputImage(BaseModel):
    # Request body for POST /Immagine. generate_image() forwards these fields
    # positionally to the remote Gradio Space's predict() call, in this order.
    input: str  # text prompt (first positional argument to predict())
    negativePrompt: str = ''  # second positional argument; '' by default
    steps: int = 25  # presumably the number of diffusion steps — confirm against the Space's API
    # Presumably the guidance (CFG) scale.
    # NOTE(review): typed as int, so fractional values such as 7.5 may be rejected
    # or truncated depending on the pydantic version — consider float.
    cfg: int = 5
    seed: int = 453666937  # fixed default seed, presumably so repeated requests are reproducible
def format_prompt(message, history):
    """Build a Mixtral-instruct prompt string from a message and prior turns.

    Args:
        message: The new user message.
        history: Iterable of ``(user_turn, assistant_turn)`` pairs, oldest first.

    Returns:
        ``"<s>"``, then one ``"[INST] u [/INST] a</s> "`` segment per past turn,
        then ``"[<timestamp>] [INST] message [/INST]"``. The timestamp is the
        current local time with microseconds.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {assistant_turn}</s> "
        for user_turn, assistant_turn in history
    )
    # The microsecond timestamp makes every prompt textually unique.
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    return f"<s>{past_turns}[{stamp}] [INST] {message} [/INST]"
@app.post("/Genera")
def read_root(request: Request, input_data: InputData):
    """Generate a text completion for ``input_data.input``.

    Args:
        request: Incoming request (unused; kept for the FastAPI signature).
        input_data: Message plus sampling settings, see ``InputData``.

    Returns:
        ``{"response": <generated text>}``.
    """
    # No conversation memory is kept: every request starts with an empty history.
    completion = generate(
        input_data.input,
        [],
        input_data.temperature,
        input_data.max_new_tokens,
        input_data.top_p,
        input_data.repetition_penalty,
    )
    return {"response": completion}
@app.post("/Immagine")
def generate_image(request: Request, input_data: InputImage):
    """Render an image on a remote SDXL Gradio Space and return it base64-encoded.

    Args:
        request: Incoming request (unused; kept for the FastAPI signature).
        input_data: Prompt and sampling settings, see ``InputImage``.

    Returns:
        ``{"response": <base64 string>}``, the UTF-8 base64 encoding of the file
        whose path the Space's ``predict`` call returns.
    """
    # Target the Space's stable base URL instead of a pinned replica
    # ("/--replicas/545b5tw7n/"). Replica IDs are reassigned when the Space
    # restarts or rescales, after which a pinned replica URL no longer resolves.
    image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/")
    image_path = image_client.predict(
        input_data.input,
        input_data.negativePrompt,
        input_data.steps,
        input_data.cfg,
        1024,  # presumably width in pixels — confirm against the Space's API
        1024,  # presumably height in pixels — confirm against the Space's API
        input_data.seed,
        fn_index=0,
    )
    # The result is treated as a local file path: read it and base64-encode it.
    with open(image_path, "rb") as img_file:
        img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
    return {"response": img_base64}
@app.get("/")
def read_general():
    """Return a static (Italian) welcome message that points callers to /docs."""
    welcome_text = "Benvenuto. Per maggiori info vai a /docs"
    return {"response": welcome_text}
def generate(prompt, history, temperature=0.2, max_new_tokens=30000, top_p=0.95, repetition_penalty=1.0):
    """Run one non-streaming text generation on the module-level Mixtral client.

    Args:
        prompt: The new user message.
        history: ``(user_turn, assistant_turn)`` pairs passed to ``format_prompt``.
        temperature: Sampling temperature; converted to float and floored at 0.01.
        max_new_tokens: Forwarded unchanged to ``text_generation``.
        top_p: Converted to float and forwarded.
        repetition_penalty: Forwarded unchanged.

    Returns:
        Whatever ``client.text_generation`` returns with ``stream=False`` and
        ``details=False``.
    """
    # Floor the temperature at 0.01 instead of passing 0 or a negative value through.
    floored_temperature = max(float(temperature), 1e-2)
    sampling_options = {
        "temperature": floored_temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        # Fixed seed: identical prompts with identical settings sample the same way.
        "seed": 42,
    }
    return client.text_generation(
        format_prompt(prompt, history),
        stream=False,
        details=False,
        **sampling_options,
    )