-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfile.py
More file actions
114 lines (94 loc) · 3.53 KB
/
Copy pathfile.py
File metadata and controls
114 lines (94 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import re
import unicodedata
from concurrent.futures import ThreadPoolExecutor
from math import ceil
from pathlib import Path

import nltk
import torch
from sentence_transformers import SentenceTransformer
from sklearn.cluster import SpectralClustering
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Sentence splitting (nltk.sent_tokenize) needs the Punkt data. Newer NLTK
# releases look up "punkt_tab" while older ones use "punkt", so fetch both to
# keep the script working across versions.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# Device selection: MPS if available, else CPU. Only the Llama model below is
# moved to this device explicitly.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Initialize models
# BART tokenizer: used to measure block lengths and shared with the
# summarisation pipeline so the same tokenizer is not loaded twice.
tok_sum = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
# NOTE(review): no device is passed, so SentenceTransformer chooses one itself
# (it may pick an available GPU/MPS); pass device="cpu" if CPU-only is intended.
encoder = SentenceTransformer("all-MiniLM-L6-v2")
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    tokenizer=tok_sum,
    # No device argument: the pipeline runs on CPU.
    batch_size=4,
)
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
# NOTE(review): no torch_dtype is given, so the 7B model presumably loads in
# full precision; consider torch_dtype=torch.float16 on MPS to cut memory.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf").to(DEVICE)
# Utility functions
def load_text(path: Path) -> str:
    """Read *path* as UTF-8, NFKC-normalise it and collapse all whitespace.

    Bytes that are not valid UTF-8 are dropped (``errors="ignore"``). Every run
    of whitespace becomes a single space, and leading/trailing whitespace is
    removed.
    """
    content = path.read_text(encoding="utf-8", errors="ignore")
    content = unicodedata.normalize("NFKC", content)
    # str.split() with no argument splits on the same Unicode whitespace set as
    # the regex class \s and drops empty pieces, so joining with " " both
    # collapses interior runs and trims the ends.
    return " ".join(content.split())
def sentence_blocks(text: str, limit=512):
    """Greedily pack consecutive sentences of *text* into token-bounded blocks.

    Sentences come from ``nltk.sent_tokenize``. Each sentence joins the block
    being built unless the space-joined result would encode (with the BART
    tokenizer ``tok_sum``) to more than *limit* tokens. In that case the
    current block is closed and the sentence opens a new one. A sentence that
    alone exceeds *limit* still becomes its own block.

    Returns:
        List of ``(block_text, first_sentence_index)`` tuples in document
        order; ``block_text`` is the block's sentences joined by single spaces.
    """
    sentences = nltk.sent_tokenize(text)
    packed = []
    head = 0  # index of the first sentence in the block being built
    for idx, sentence in enumerate(sentences):
        current = sentences[head:idx]
        candidate = " ".join(current + [sentence])
        if len(tok_sum.encode(candidate)) > limit and current:
            packed.append((" ".join(current), head))
            head = idx
    if head < len(sentences):
        packed.append((" ".join(sentences[head:]), head))
    return packed
def similarity_graph(block_texts):
    """Return the pairwise cosine-similarity matrix of the blocks' embeddings.

    Each text is embedded with the module-level SentenceTransformer ``encoder``
    (batches of 64). Entry ``[i, j]`` of the result is the cosine similarity
    between blocks ``i`` and ``j``.
    """
    with torch.inference_mode():
        embeddings = encoder.encode(block_texts, convert_to_tensor=True, batch_size=64)
        # sklearn operates on host memory, so move the tensor to the CPU first.
        matrix = cosine_similarity(embeddings.cpu())
    return matrix
def cluster_blocks(sim_matrix, block_tokens, target=450):
    """Group blocks into clusters of roughly *target* tokens each.

    The requested cluster count is ``ceil(sum(block_tokens) / target)``,
    capped at the number of blocks. Spectral clustering on *sim_matrix* then
    decides which blocks share a cluster. It does not balance token totals,
    so individual clusters can be larger than *target*.

    Args:
        sim_matrix: Square pairwise similarity matrix, one row/column per block.
        block_tokens: Token count of each block.
        target: Desired number of tokens per cluster.

    Returns:
        Integer label array with one entry per block.
    """
    import numpy as np  # local import: the file does not import numpy at top level

    n_blocks = len(block_tokens)
    if n_blocks == 0:
        return np.zeros(0, dtype=int)
    n_clusters = ceil(sum(block_tokens) / target)
    # Trivial cases need no solver. One cluster means every block shares a
    # label. More clusters than blocks is unsatisfiable (SpectralClustering
    # would fail), so each block gets its own cluster instead.
    if n_clusters <= 1:
        return np.zeros(n_blocks, dtype=int)
    if n_clusters >= n_blocks:
        return np.arange(n_blocks)
    # A precomputed affinity must be non-negative, but cosine similarity lies
    # in [-1, 1]; treat dissimilar (negative) pairs as unconnected.
    affinity = np.clip(np.asarray(sim_matrix, dtype=float), 0.0, None)
    sc = SpectralClustering(
        n_clusters=n_clusters,
        affinity="precomputed",
        assign_labels="discretize",
        random_state=0,
    )
    return sc.fit_predict(affinity)
def compress_cluster(text):
    """Summarise one cluster's text with the module-level BART ``summarizer``.

    Args:
        text: The space-joined block texts of a single cluster.

    Returns:
        The generated summary string. Decoding is deterministic
        (``do_sample=False``); the summary length is bounded by
        ``min_length``/``max_length``.
    """
    # A cluster concatenates several ~512-token blocks, so its length is not
    # bounded by the summariser's input window. Without truncation, over-long
    # inputs make the pipeline error out instead of summarising.
    # NOTE(review): truncation drops the tail of very large clusters; splitting
    # them before summarising would preserve more content.
    result = summarizer(
        text,
        max_length=256,
        min_length=64,
        do_sample=False,
        truncation=True,
    )
    return result[0]["summary_text"]
def summarise_chunks(blocks, labels, *, compress=None):
    """Summarise every cluster of blocks, returning summaries in reading order.

    Args:
        blocks: ``(text, position)`` pairs, where ``position`` orders blocks
            within the document (``sentence_blocks`` yields these).
        labels: Cluster label for each entry of *blocks*.
        compress: Callable mapping one cluster's joined text to its summary.
            Defaults to ``compress_cluster``.

    Returns:
        One summary per cluster. Inside a cluster, blocks are joined in
        position order. Clusters are ordered by their earliest block, so the
        summaries follow the document's order.

    Raises:
        ValueError: If *blocks* and *labels* have different lengths.
    """
    if compress is None:
        compress = compress_cluster
    blocks = list(blocks)
    labels = list(labels)
    if len(blocks) != len(labels):
        raise ValueError(f"got {len(blocks)} blocks but {len(labels)} labels")
    clusters = {}
    for (txt, pos), lab in zip(blocks, labels):
        clusters.setdefault(lab, []).append((pos, txt))
    # Label values are arbitrary identifiers from the clustering step, so order
    # clusters by where their first block appears rather than by label value.
    members = sorted((sorted(v) for v in clusters.values()), key=lambda m: m[0][0])
    ordered = [" ".join(t for _, t in m) for m in members]
    # NOTE(review): the default compressor shares one pipeline across threads;
    # confirm that is safe/beneficial for the installed transformers version.
    with ThreadPoolExecutor(max_workers=4) as pool:
        return list(pool.map(compress, ordered))
# Main pipeline
if __name__ == "__main__":
    path = Path("input.txt")
    text = load_text(path)

    # 1. Split the document into token-bounded sentence blocks.
    blocks = sentence_blocks(text)
    block_texts = [txt for txt, _ in blocks]
    block_tokens = [len(tok_sum.encode(t)) for t in block_texts]

    # 2. Cluster similar blocks and summarise each cluster.
    sim_matrix = similarity_graph(block_texts)
    labels = cluster_blocks(sim_matrix, block_tokens)
    compressed_chunks = summarise_chunks(blocks, labels)
    compressed_text = "\n\n".join(compressed_chunks)

    # 3. Ask the chat model a question, with the compressed text as context.
    question = "Summarize the protagonist's character arc."
    messages = [
        {"role": "system", "content": compressed_text},
        {"role": "user", "content": question},
    ]
    # Let the tokenizer's chat template render the model's own prompt layout
    # (system/user markers and a single BOS token). Hand-written tags would not
    # match the format the chat model was tuned on.
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]
    # NOTE(review): a long compressed_text plus max_new_tokens may exceed the
    # model's context window — confirm against model.config.max_position_embeddings.
    out = model.generate(
        **inputs,
        max_new_tokens=512,
        # The Llama-2 tokenizer ships without a pad token; reuse EOS explicitly.
        pad_token_id=tok.eos_token_id,
    )
    # generate() returns prompt + continuation; print only the new tokens.
    print(tok.decode(out[0, prompt_len:], skip_special_tokens=True))