astro21 commited on
Commit
27367c2
·
verified ·
1 Parent(s): 64c1f09

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -19
main.py CHANGED
@@ -1,6 +1,7 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- import transformers
 
4
  from fastapi.middleware.cors import CORSMiddleware
5
  import os
6
  from huggingface_hub import login
@@ -22,23 +23,12 @@ app.add_middleware(
22
  allow_headers=["*"],
23
  )
24
 
25
- # Load the model and tokenizer from Hugging Face, set device to CPU
26
- model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct" # Replace with an appropriate model
27
- tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
28
- model = transformers.AutoModelForCausalLM.from_pretrained(
29
- model_id,
30
- # Removed device_map and low_cpu_mem_usage to avoid the need for 'accelerate'
31
- )
32
 
33
- # Set up the text generation pipeline for CPU
34
- pipeline = transformers.pipeline(
35
- "text-generation",
36
- model=model,
37
- tokenizer=tokenizer,
38
- max_new_tokens=150,
39
- temperature=0.7,
40
- device=-1 # Force CPU usage
41
- )
42
 
43
  # Define the request model for email input
44
  class EmailRequest(BaseModel):
@@ -47,11 +37,38 @@ class EmailRequest(BaseModel):
47
  recipients: str
48
  body: str
49
 
50
- # Helper function to create the email prompt
51
  def create_email_prompt(subject, sender, recipients, body):
52
- prompt = f"Subject: {subject}\nFrom: {sender}\nTo: {recipients}\n\n{body}\n\nSummarize this email."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  return prompt
54
 
 
 
 
 
 
55
  # Define the FastAPI endpoint for email summarization
56
  @app.post("/summarize-email/")
57
  async def summarize_email(email: EmailRequest):
@@ -61,3 +78,4 @@ async def summarize_email(email: EmailRequest):
61
  summary = pipeline(prompt)[0]["generated_text"]
62
 
63
  return {"summary": summary}
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ import torch
4
+ from transformers import pipeline
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import os
7
  from huggingface_hub import login
 
23
  allow_headers=["*"],
24
  )
25
 
 
 
 
 
 
 
 
26
 
27
+
28
+
29
+
30
+ pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.bfloat16, device_map="auto")
31
+
 
 
 
 
32
 
33
  # Define the request model for email input
34
  class EmailRequest(BaseModel):
 
37
  recipients: str
38
  body: str
39
 
40
+
41
  def create_email_prompt(subject, sender, recipients, body):
42
+ messages = [
43
+ {
44
+ "role": "system",
45
+ "content": "You are an email summarizer. Your goal is to provide a concise summary by focusing on key points, action items, and urgency."
46
+ },
47
+ {
48
+ "role": "user",
49
+ "content": f"""
50
+ Summarize the following email by focusing on the key points, action items, and urgency.
51
+
52
+ Email Details:
53
+ Subject: {subject}
54
+ Sender: {sender}
55
+ Recipients: {recipients}
56
+
57
+ Body:
58
+ {body}
59
+
60
+ Provide a concise summary that includes important information, if any actions are required, and the priority of the email.
61
+ """
62
+ }
63
+ ]
64
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
65
  return prompt
66
 
67
+
68
+
69
+
70
+
71
+
72
  # Define the FastAPI endpoint for email summarization
73
  @app.post("/summarize-email/")
74
  async def summarize_email(email: EmailRequest):
 
78
  summary = pipeline(prompt)[0]["generated_text"]
79
 
80
  return {"summary": summary}
81
+