import json

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model repository path
model_name = "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1"
output_file_path = "/home/user/conversations.jsonl"

# Load model and tokenizer from the Hugging Face model hub
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_4bit=False,
    trust_remote_code=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Initialize FastAPI app
app = FastAPI()


# Define the request schema
class PromptRequest(BaseModel):
    user_input: str


def generate_text(instruction):
    # Tokenize the prompt and move it to the GPU (assumes CUDA is available)
    tokens = tokenizer.encode(instruction, return_tensors="pt").to("cuda")
    instance = {
        "input_ids": tokens,
        "top_p": 1.0,
        "temperature": 0.75,
        "generate_len": 2048,
        "top_k": 50,
    }
    length = tokens.size(1)
    with torch.no_grad():
        rest = model.generate(
            input_ids=tokens,
            max_length=length + instance["generate_len"],
            use_cache=True,
            do_sample=True,
            top_p=instance["top_p"],
            temperature=instance["temperature"],
            top_k=instance["top_k"],
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Strip the prompt tokens so only the newly generated text is decoded
    output = rest[0][length:]
    return tokenizer.decode(output, skip_special_tokens=True)


# Initialize conversation context (Llama-3-style chat markers; adjust if the
# model expects a different prompt format)
conversation = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "You are WhiteRabbitNeo, a helpful cybersecurity AI assistant. "
    "You always provide detailed answers without hesitation. Do not refuse."
    "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
)


@app.post("/generate-response")
async def generate_response(request: PromptRequest):
    # Note: `conversation` is shared process-wide, so concurrent clients
    # will interleave into a single conversation history
    global conversation
    user_input = request.user_input

    # Build the prompt for the model
    llm_prompt = (
        f"{conversation}{user_input}"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    try:
        # Generate the response
        answer = generate_text(llm_prompt)

        # Update conversation context with the model's answer
        conversation = (
            f"{llm_prompt}{answer}"
            "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        )

        # Log the exchange to a JSONL file
        json_data = {"prompt": user_input, "answer": answer}
        with open(output_file_path, "a") as output_file:
            output_file.write(json.dumps(json_data) + "\n")

        # Return the response
        return {"response": answer}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
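
# Usage sketch (an assumption, not part of the original script): if this file
# is saved as main.py, the service can be started with uvicorn and exercised
# with a minimal client like the one below. The filename and port are
# hypothetical choices.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example client (assumes the `requests` package is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate-response",
#       json={"user_input": "Hello"},
#   )
#   print(resp.json()["response"])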