OjciecTadeusz committed on
Commit 37e4010 · verified · 1 Parent(s): e74b14f

Update app.py

Files changed (1)
  1. app.py +92 -64
app.py CHANGED
@@ -1,81 +1,109 @@
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import torch

 app = FastAPI()

-# Model configuration
-MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Initialize sentiment analysis model
-sentiment_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-sentiment_classifier = pipeline(
-    "sentiment-analysis",
-    model=MODEL_NAME,
-    tokenizer=sentiment_tokenizer,
-    device=DEVICE
 )

-# Initialize GPT-2 for text generation
-MODEL_NAME_LARGE = "gpt2-large"
-generation_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_LARGE)
-generation_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_LARGE).to(DEVICE)
-
-class TextInput(BaseModel):
-    text: str
-
-class GenerationInput(BaseModel):
-    prompt: str
-    max_length: int = 100
-
-@app.post("/analyze-sentiment")
-async def analyze_sentiment(input_data: TextInput):
-    try:
-        result = sentiment_classifier(input_data.text)
-        return {
-            "sentiment": result[0]['label'],
-            "score": float(result[0]['score'])
         }
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))

-@app.post("/generate-text")
-async def generate_text(input_data: GenerationInput):
     try:
-        inputs = generation_tokenizer(
-            input_data.prompt,
-            return_tensors="pt"
-        ).to(DEVICE)

-        outputs = generation_model.generate(
-            inputs["input_ids"],
-            max_length=input_data.max_length,
-            num_return_sequences=1,
-            no_repeat_ngram_size=2,
-            pad_token_id=generation_tokenizer.eos_token_id
         )

-        generated_text = generation_tokenizer.decode(
-            outputs[0],
-            skip_special_tokens=True
         )

-        return {"generated_text": generated_text}
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))

-@app.get("/health")
-async def health_check():
-    return {
-        "status": "healthy",
-        "sentiment_model": MODEL_NAME,
-        "generation_model": MODEL_NAME_LARGE,
-        "device": str(DEVICE)
-    }

-# Add this at the end of the file
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)

+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import json
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+import datetime

+# Initialize FastAPI
 app = FastAPI()

+# Load model and tokenizer
+model_name = "Qwen/Qwen2.5-Coder-32B"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    trust_remote_code=True,
+    torch_dtype=torch.float16
 )

+def format_chat_response(response_text, prompt_tokens, completion_tokens):
+    return {
+        "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
+        "object": "chat.completion",
+        "created": int(datetime.datetime.now().timestamp()),
+        "model": model_name,
+        "choices": [{
+            "index": 0,
+            "message": {
+                "role": "assistant",
+                "content": response_text
+            },
+            "finish_reason": "stop"
+        }],
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens
         }
+    }

+@app.post("/v1/chat/completions")
+async def chat_completion(request: Request):
     try:
+        data = await request.json()
+        messages = data.get("messages", [])
+
+        # Format messages for Qwen
+        conversation = []
+        for msg in messages:
+            conversation.append({
+                "role": msg["role"],
+                "content": msg["content"]
+            })

+        # Convert messages to model input format
+        prompt = tokenizer.apply_chat_template(
+            conversation,
+            tokenize=False,
+            add_generation_prompt=True
         )

+        # Count prompt tokens
+        prompt_tokens = len(tokenizer.encode(prompt))
+
+        # Generate response
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=data.get("max_tokens", 2048),
+            temperature=data.get("temperature", 0.7),
+            top_p=data.get("top_p", 0.95),
+            do_sample=True
         )

+        response_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+        completion_tokens = len(tokenizer.encode(response_text))
+
+        return JSONResponse(
+            content=format_chat_response(response_text, prompt_tokens, completion_tokens)
+        )
     except Exception as e:
+        return JSONResponse(
+            status_code=500,
+            content={"error": str(e)}
+        )

+# Gradio interface for testing
+def chat_interface(message, history):
+    history = history or []
+    messages = []
+
+    # Add history to messages (Gradio history is a list of [user, assistant] pairs)
+    for user_msg, assistant_msg in history:
+        messages.append({"role": "user", "content": user_msg})
+        messages.append({"role": "assistant", "content": assistant_msg})
+    messages.append({"role": "user", "content": message})
+
+    # Generate a reply the same way the /v1/chat/completions endpoint does
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    outputs = model.generate(**inputs, max_new_tokens=2048, temperature=0.7, top_p=0.95, do_sample=True)
+    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+
+interface = gr.ChatInterface(
+    chat_interface,
+    title="Qwen2.5-Coder-32B Chat",
+    description="Chat with Qwen2.5-Coder-32B model. This Space also provides a /v1/chat/completions endpoint."
+)

+# Mount both FastAPI and Gradio
+app = gr.mount_gradio_app(app, interface, path="/")
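
For quick testing of the new endpoint, a minimal client sketch is shown below. The base URL and port are assumptions (replace them with wherever this app is actually served, e.g. the public URL of the Space); the request body and response fields match what the chat_completion handler and format_chat_response above read and return.

import requests

BASE_URL = "http://localhost:7860"  # assumption: local run; replace with the Space's URL

payload = {
    "messages": [
        {"role": "user", "content": "Write a Python function that reverses a string."}
    ],
    "max_tokens": 256,      # mapped to max_new_tokens by the endpoint
    "temperature": 0.7,
    "top_p": 0.95,
}

resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
data = resp.json()

print(data["choices"][0]["message"]["content"])  # assistant reply
print(data["usage"])                             # prompt / completion / total token counts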