"""FastAPI service exposing an LLM-based retrosynthesis predictor.

Accepts a molecule as a SMILES string, converts it to SELFIES, prompts a
llama.cpp model for candidate precursors, and returns each candidate
decoded back to SMILES.
"""

from llama_cpp import Llama
from fastapi import FastAPI, Form, HTTPException
from pydantic import BaseModel
import torch
import selfies as sf

app = FastAPI(title="Retrosynthesis Prediction API", version="1.0")

# Load the model and tokenizer.
model_name = "model.gguf"  # Replace with your model path
test_model = Llama(model_name)

# Number of candidate precursors returned per request.
NUM_PREDICTIONS = 10


class RequestBody(BaseModel):
    """Request payload: the input molecule plus sampling parameters.

    prompt: SMILES string of the target molecule.
    temperature / top_k / top_p: sampling controls forwarded to llama.cpp.
    """
    prompt: str
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 1.0


def _extract_inst_span(text: str) -> str:
    """Return the text between the first "[/INST]" tag and the next "[/IN"
    marker, stripped of surrounding whitespace.

    Falls back gracefully when a marker is missing: with no "[/INST]" the
    whole text is returned; with no closing "[/IN" everything after the
    opening tag is returned. (str.find returns -1 when absent — slicing
    with -1 directly would silently drop the last character.)
    """
    open_tag = "[/INST]"
    start = text.find(open_tag)
    if start == -1:
        return text.strip()
    start += len(open_tag)
    end = text.find("[/IN", start + 1)
    if end == -1:
        return text[start:].strip()
    return text[start:end].strip()


@app.post("/generate/")
async def generate_text(request: RequestBody):
    """Generate NUM_PREDICTIONS candidate precursors for the input molecule.

    Returns a dict with the encoded input under "input" and the decoded
    candidates under "predict_1" .. "predict_10". Any failure is surfaced
    as HTTP 500 with the underlying error message.
    """
    try:
        # SMILES -> SELFIES; the model was trained on SELFIES prompts.
        prompt = sf.encoder(request.prompt)
        result = {'input': prompt}
        # llama_cpp has no num_return_sequences/num_beams; sample the model
        # once per candidate instead. Its __call__ accepts max_tokens (not
        # the transformers-style max_new_tokens) and returns an OpenAI-style
        # completion dict, with the generated text in choices[0]["text"].
        for i in range(NUM_PREDICTIONS):
            completion = test_model(
                prompt,
                max_tokens=512,
                top_k=request.top_k,
                top_p=request.top_p,
                temperature=request.temperature,
            )
            generated = completion["choices"][0]["text"]
            predicted_selfies = _extract_inst_span(generated)
            # SELFIES -> SMILES for the client.
            result[f'predict_{i+1}'] = sf.decoder(predicted_selfies)
        return result
    except Exception as e:
        # Preserve the cause chain for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
async def read_root():
    """Health/landing endpoint."""
    return {"message": "Welcome to the RetroLLM app!"}